-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmetric_train.py
More file actions
118 lines (103 loc) · 4.22 KB
/
metric_train.py
File metadata and controls
118 lines (103 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import sys
import os
import torch
import argparse
import numpy as np
from pathlib import Path
from DataSets import create_datasets, create_dataloader
from Utils.eval import eval_metric_model
from Utils.tools import analysis_dataset
from Utils.ddp_tools import init_env, save_model, save_criterion, copy_model
from Models.Backbone import create_backbone
from Models.Loss import create_metric_loss
from Models.Optimizer import create_optimizer
from Models.Scheduler import create_scheduler
from torchinfo import summary
from pytorch_metric_learning import miners
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
import colossalai
cur_path = os.path.abspath(os.path.dirname(__file__))  # NOTE(review): defined but never used in this script — confirm before removing
if __name__ == "__main__":
    """Metric-learning training entry point (distributed via ColossalAI)."""
    parser = colossalai.get_default_parser()
    parser.add_argument("--config_file", help="训练配置", default="./Config/config.py")
    # Initialize the run environment: checkpoint dir, parsed config,
    # TensorBoard writer, and distributed logger.
    ckpt_path, cfg, tb_writer, logger = init_env(parser.parse_args().config_file)
    # Dataset: parse the txt-based dataset description, then build the
    # augmented training set and its dataloader.
    dataset = analysis_dataset(cfg.Txt)
    tb_writer.add_dataset_info(dataset)
    train_set = create_datasets(
        dataset=dataset["train"], size=cfg.Size, process=cfg.Process, use_augment=True
    )
    train_dataloader = create_dataloader(cfg.Batch, train_set, cfg.Sampler)
    # Model: if cfg.Backbone is an existing file, resume a fully serialized
    # model; otherwise create a backbone by name with a metric head.
    if os.path.isfile(cfg.Backbone):
        model = torch.load(cfg.Backbone, map_location="cpu")
    else:
        model = create_backbone(cfg.Backbone, cfg.Feature_dim, metric=True)
    model.info = {"task": "metric", "all_labels": dataset["all_labels"]}  # extra metadata carried with the model
    cp_model = copy_model(model)  # unwrapped copy used later when saving checkpoints
    tb_writer.add_model_info(model, cfg.Size)
    # Loss function / classifier
    mining_func = miners.MultiSimilarityMiner()  # hard-example mining
    if os.path.isfile(cfg.Loss):
        # NOTE(review): no map_location="cpu" here, unlike the backbone load
        # above — confirm this is intended when resuming on a different device.
        criterion = torch.load(cfg.Loss)
    else:
        criterion = create_metric_loss(
            cfg.Loss, cfg.Feature_dim, len(dataset["all_labels"])
        )
    # # Classifier: initialize from precomputed class centers, loading the weights (optional manual step)
    # classcenter = np.load("home/xxx.npy")
    # criterion.W.data = torch.from_numpy(classcenter.T)  # transposed class centers
    # Optimizer: both the backbone and the loss/classifier parameters are
    # trained with the same learning rate.
    params = [
        {"params": model.parameters(), "lr": cfg.LR},
        {"params": criterion.parameters(), "lr": cfg.LR},
    ]
    optimizer = create_optimizer(cfg.Optimizer, params, lr=cfg.LR)
    # Learning-rate scheduler
    lr_scheduler = create_scheduler(cfg.Scheduler, cfg.Epochs, optimizer)
    # Wrap model/optimizer/criterion/dataloader into a ColossalAI engine.
    engine, train_dataloader, _, _ = colossalai.initialize(
        model,
        optimizer,
        criterion,
        train_dataloader,
    )
    best_score = 0.0
    for epoch in range(cfg.Epochs):
        engine.train()
        logger.info(f"Starting {epoch} / {cfg.Epochs}", ranks=[0])
        for batch_idx, (imgs, labels) in enumerate(train_dataloader):
            imgs, labels = imgs.cuda(), labels.cuda()
            engine.zero_grad()
            output = engine(imgs)  # embeddings for the batch
            hard_tuples = mining_func(output, labels)  # mine hard pairs/triplets from the embeddings
            loss = engine.criterion(output, labels, hard_tuples)
            engine.backward(loss)
            engine.step()
            if batch_idx % 100 == 0:
                # Global iteration index so the TensorBoard curve is continuous across epochs.
                iter_num = int(batch_idx + epoch * len(train_dataloader))
                tb_writer.add_scalar("Train/loss", loss.item(), iter_num)
        # Evaluate on the validation set; checkpoint when the score improves.
        engine.eval()
        score = eval_metric_model(
            engine, dataset, cfg.Size, cfg.Process, cfg.Batch, mode="val"
        )
        if best_score <= score["value"]:
            best_score = score["value"]
            # NOTE(review): paths are built by string concatenation — assumes
            # ckpt_path ends with a path separator; confirm against init_env.
            save_model(
                engine.model, cp_model, ckpt_path + Path(cfg.Backbone).stem + "_best.pt"
            )
            save_criterion(
                engine.criterion, ckpt_path + Path(cfg.Loss).stem + "_best.pt"
            )
        # Visualization: last batch's augmented images, current LR, and the
        # validation metric for this epoch.
        tb_writer.add_augment_imgs(epoch, imgs, labels, dataset["all_labels"])
        tb_writer.add_scalar("Train/lr", lr_scheduler.get_last_lr()[0], epoch)
        tb_writer.add_scalar("Val/" + score["index"], score["value"], epoch)
        lr_scheduler.step()
    # Always save the final-epoch weights alongside the best checkpoint.
    save_model(engine.model, cp_model, ckpt_path + Path(cfg.Backbone).stem + "_last.pt")
    save_criterion(engine.criterion, ckpt_path + Path(cfg.Loss).stem + "_last.pt")
    tb_writer.close()