train_single_gpu.py
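"""Single-GPU training script for the mouse behavior model.

Loads the Formalin video/pose/audio dataset, trains a VisionModel
(alternatives: AudioModel, MouseModel) with BCEWithLogitsLoss, and logs
training loss and test accuracy to TensorBoard.
"""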
import argparse

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MouseDataset
from model import MouseModel, VisionModel, AudioModel

def parse_args():
    parser = argparse.ArgumentParser()
    # data paths
    parser.add_argument('--video_path', default="/data/zhaozhenghao/Projects/Mouse_behavior/dataset/Formalin/frame_folder", type=str, dest='video_path', help='Video path.')
    parser.add_argument('--pred_path', default='/data/zhaozhenghao/Projects/Mouse_behavior/track_result/Formalin/sideview_pose_ckpt1/alphapose-results.json', type=str, dest='pred_path', help='Prediction path.')
    parser.add_argument('--label_path', default='/data/zhaozhenghao/Projects/Mouse_behavior/dataset/Formalin/Formalin_acute_pain_1.csv', type=str, dest='label_path', help='Label path.')
    parser.add_argument('--audio_path', default='/data/zhaozhenghao/Projects/Mouse_behavior/dataset/Formalin/Formalin_Ultrasound_recording.wav', type=str, dest='audio_path', help='Audio path.')
    # hyperparameters
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    # sliding-window / audio parameters
    parser.add_argument('--resampling_rate', type=int, default=1500, help='Resampling rate for audio.')
    parser.add_argument('--step', type=int, default=10, help='Step size for sliding window.')
    parser.add_argument('--stride', type=int, default=10, help='Stride size for sliding window.')
    # tensorboard
    # NOTE: argparse's type=bool treats any non-empty string as True, so use
    # BooleanOptionalAction (--tensorboard / --no-tensorboard, Python 3.9+) instead.
    parser.add_argument('--tensorboard', action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument('--log_dir', type=str, default='./logs')
    # kept for compatibility with DDP launchers; unused in this single-GPU script
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    return args

def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.tensorboard:
        writer = SummaryWriter(args.log_dir)

    # dataset: 80/20 train/test split
    dataset = MouseDataset(args.video_path, args.pred_path, args.label_path, args.audio_path, args)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
    # model: vision-only by default; swap in the audio or multimodal variant as needed
    # NOTE: the DistributedDataParallel wrapper (which referenced an undefined
    # `local_rank`) belongs in the DDP variant and has been removed here.
    model = VisionModel(num_meta_features=12, num_classes=3).to(device)
    # model = AudioModel(num_audio_features=50, num_classes=3).to(device)
    # model = MouseModel(num_meta_features=12, num_audio_features=50, num_classes=3).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=args.learning_rate)
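    # Training treats the 3 output logits as independent binary labels
    # (multi-label setup): BCEWithLogitsLoss here, and during evaluation the
    # sigmoid outputs are thresholded at 0.5, with a sample counted as correct
    # only when all labels match.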
    for epoch in range(args.num_epochs):
        model.train()
        for steps, (images, behavior_feat, audio, labels) in enumerate(tqdm(train_loader)):
            images, behavior_feat, audio, labels = images.to(device), behavior_feat.to(device), audio.to(device), labels.to(device)
            outputs = model(images, behavior_feat)
            # outputs = model(audio)
            # outputs = model(images, behavior_feat, audio)
            loss = criterion(outputs, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.tensorboard:
                writer.add_scalar('Train/Loss', loss.item(), epoch * len(train_loader) + steps)
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, behavior_feat, audio, labels in test_loader:
                images, behavior_feat, audio, labels = images.to(device), behavior_feat.to(device), audio.to(device), labels.to(device)
                outputs = model(images, behavior_feat)
                # outputs = model(audio)
                # outputs = model(images, behavior_feat, audio)
                probs = torch.sigmoid(outputs)
                preds = (probs > 0.5).float()
                total += labels.size(0)
                correct += (preds == labels).all(dim=1).sum().item()
        accuracy = 100 * correct / total
        if args.tensorboard:
            writer.add_scalar('Test/Accuracy', accuracy, epoch)
        print(f"Epoch [{epoch+1}/{args.num_epochs}], Loss: {loss.item():.4f}, Test Accuracy: {accuracy:.2f}%")

    print("Training finished!")
    if args.tensorboard:
        writer.close()

if __name__ == "__main__":
    args = parse_args()
    main(args)
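
# Example invocation (flags shown are the defaults defined in parse_args;
# adjust the data paths for your environment):
#   python train_single_gpu.py \
#       --batch_size 32 --num_epochs 10 --learning_rate 0.001 \
#       --log_dir ./logs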