forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaverage_checkpoints.py
More file actions
executable file
·98 lines (85 loc) · 3.17 KB
/
average_checkpoints.py
File metadata and controls
executable file
·98 lines (85 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import numpy as np
def _select_snapshots(args):
    """Return the snapshot paths to average, as chosen by ``args``.

    With ``args.log`` set, the JSON log is scanned and the ``args.num``
    epochs with the best validation score are selected. The score is
    ``validation/main/acc`` if present, otherwise ``1 / val_perplexity``,
    otherwise ``-validation/main/loss`` (larger is better in every case).
    The returned paths are ``snapshot.ep.<epoch>`` files located in the
    directory of ``args.snapshots[0]``.

    Without a log, the ``args.num`` most recently modified files among
    ``args.snapshots`` are returned.

    Either way the result may hold fewer than ``args.num`` paths.

    Raises:
        ValueError: if the log contains none of the recognised score keys.
    """
    if args.log is None:
        # No training log: take the newest snapshots by modification time.
        by_mtime = sorted(args.snapshots, key=os.path.getmtime)
        return by_mtime[-args.num :]

    with open(args.log) as f:
        logs = json.load(f)
    val_scores = []
    for log in logs:
        # Each score is oriented so that a larger value means a better model.
        if "validation/main/acc" in log:
            val_scores.append([log["epoch"], log["validation/main/acc"]])
        elif "val_perplexity" in log:
            val_scores.append([log["epoch"], 1 / log["val_perplexity"]])
        elif "validation/main/loss" in log:
            val_scores.append([log["epoch"], -log["validation/main/loss"]])
    if not val_scores:
        raise ValueError(
            "`validation/main/acc`, `val_perplexity` or `validation/main/loss` "
            "is not found in log."
        )
    val_scores = np.array(val_scores)
    sort_idx = np.argsort(val_scores[:, -1])
    sorted_val_scores = val_scores[sort_idx][::-1]  # best score first
    print("best val scores = " + str(sorted_val_scores[: args.num, 1]))
    print(
        "selected epochs = "
        + str(sorted_val_scores[: args.num, 0].astype(np.int64))
    )
    # os.path.join avoids building "/snapshot.ep.N" (the filesystem root)
    # when the snapshot path has no directory component.
    snapshot_dir = os.path.dirname(args.snapshots[0])
    return [
        os.path.join(snapshot_dir, "snapshot.ep.%d" % int(epoch))
        for epoch in sorted_val_scores[: args.num, 0]
    ]


def _average_pytorch(paths, out):
    """Average the ``"model"`` state dicts of PyTorch snapshots into ``out``."""
    import torch

    avg = None
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True; snapshots holding non-tensor objects may need
    # weights_only=False -- confirm against the snapshot format.
    for path in paths:
        states = torch.load(path, map_location=torch.device("cpu"))["model"]
        if avg is None:
            avg = states
        else:
            for k in avg:
                avg[k] += states[k]
    # Divide by the number of snapshots actually summed, which may be
    # smaller than --num.
    count = len(paths)
    for k in avg:
        if avg[k] is not None:
            if avg[k].is_floating_point():
                avg[k] /= count
            else:
                # In-place true division is not allowed on integer tensors
                # (e.g. BatchNorm's num_batches_tracked).
                avg[k] //= count
    torch.save(avg, out)


def _average_chainer(paths, out):
    """Average the ``updater/model:main/*`` arrays of Chainer npz snapshots.

    The averaged arrays are written to ``out`` as a compressed npz file,
    keyed by parameter name with the ``updater/model:main/`` prefix removed.
    """
    avg = None
    keys = []
    for path in paths:
        # The context manager closes the underlying zip file.
        with np.load(path) as states:
            if avg is None:
                keys = [x.split("main/", 1)[1] for x in states if "model" in x]
                avg = {k: states["updater/model:main/{}".format(k)] for k in keys}
            else:
                for k in keys:
                    avg[k] += states["updater/model:main/{}".format(k)]
    count = len(paths)
    for k in keys:
        if np.issubdtype(avg[k].dtype, np.integer):
            # Integer arrays cannot be true-divided in place.
            avg[k] //= count
        else:
            avg[k] /= count
    np.savez_compressed(out, **avg)
    # numpy appends ".npz" only when the name lacks it; move the file back to
    # the exact path that was requested.
    if not out.endswith(".npz"):
        os.replace("{}.npz".format(out), out)


def main(args=None):
    """Average model parameters over several snapshots and save the result.

    Args:
        args: parsed options with attributes ``snapshots``, ``out``, ``num``,
            ``backend`` and ``log`` (see ``get_parser``). When omitted, they
            are parsed from the command line.

    Raises:
        ValueError: if ``args.num`` is not positive, the backend is unknown,
            or the training log has no usable validation score.
    """
    if args is None:
        args = get_parser().parse_args()
    if args.num < 1:
        raise ValueError("--num must be a positive integer, got {}".format(args.num))

    last = _select_snapshots(args)
    print("average over", last)

    if args.backend == "pytorch":
        _average_pytorch(last, args.out)
    elif args.backend == "chainer":
        _average_chainer(last, args.out)
    else:
        raise ValueError("Incorrect type of backend")
def get_parser():
    """Build the command-line parser for the snapshot-averaging script.

    Returns:
        argparse.ArgumentParser accepting ``--snapshots``, ``--out``,
        ``--num``, ``--backend`` and ``--log``.
    """
    parser = argparse.ArgumentParser(description="average models from snapshot")
    parser.add_argument(
        "--snapshots",
        required=True,
        type=str,
        nargs="+",
        help="Snapshot files. Without --log, the --num most recently modified "
        "ones are averaged; with --log, only the directory of the first one is "
        "used to locate snapshot.ep.<epoch> files.",
    )
    parser.add_argument(
        "--out", required=True, type=str, help="Output path for the averaged model."
    )
    parser.add_argument(
        "--num", default=10, type=int, help="Number of snapshots to average."
    )
    # A closed set: rejecting typos at parse time gives a usage message
    # instead of failing only after the snapshots have been selected.
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Framework that wrote the snapshots.",
    )
    parser.add_argument(
        "--log",
        default=None,
        type=str,
        nargs="?",
        help="JSON training log; if given, the snapshots with the best "
        "validation score are averaged.",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: bind the parsed options to the module-level name
    # `args`, then run main() with no arguments.
    parser = get_parser()
    args = parser.parse_args()
    main()