-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathablation_mask.py
More file actions
32 lines (25 loc) · 931 Bytes
/
ablation_mask.py
File metadata and controls
32 lines (25 loc) · 931 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import argparse
r"""Example of a single training command of the kind this script launches:
python -m torch.distributed.launch --nproc_per_node=8 --master_port=59566 --use_env train.py \
models/iter_mask/piclick_base448_cocolvis_itermask_5m.py \
--batch-size=136 \
--ngpus=8
"""
def get_args():
    """Parse the command-line options of this launcher script.

    Returns:
        argparse.Namespace: namespace with a single attribute ``ngpu``
        (int, defaults to 8) taken from the ``--ngpu`` option.
    """
    arg_parser = argparse.ArgumentParser(description='Train a detector')
    arg_parser.add_argument('--ngpu', type=int, default=8)
    parsed = arg_parser.parse_args()
    return parsed
def main():
    """Launch the ``itermask_{i}m`` training configs one after another.

    Reads ``--ngpu`` from the command line, derives the batch size as
    17 per GPU, and for each config index runs ``train.py`` through
    ``torch.distributed.launch`` in a shell, piping its output through
    ``tee`` into ``logs/piclick_base448_cocolvis_itermask_{i}m.log``.

    The runs are sequential: each ``os.system`` call blocks until the
    shell pipeline exits. The exit status is not inspected, so a failed
    run does not stop the remaining ones.
    """
    args = get_args()
    ngpu = args.ngpu
    # Batch size scales with the GPU count: 17 per GPU (136 on 8 GPUs).
    batch_size = 17 * ngpu

    # `tee` does not create missing directories; without this the pipeline
    # would still run but the per-run log file would be silently lost.
    os.makedirs("logs", exist_ok=True)

    # Order is kept exactly as in the original schedule.
    for i in [5, 3, 1, 6, 4, 2]:
        cmd = (
            f"python -m torch.distributed.launch --nproc_per_node={ngpu} --master_port=59566 --use_env train.py "
            f"models/iter_mask/piclick_base448_cocolvis_itermask_{i}m.py "
            f"--batch-size={batch_size} "
            f"--ngpus={ngpu} "
            f" | tee logs/piclick_base448_cocolvis_itermask_{i}m.log"
        )
        # NOTE(review): with a plain `sh` pipeline the returned status is
        # presumably that of `tee`, not of the training process — confirm
        # before relying on it for failure detection.
        os.system(cmd)
# Only start the training runs when executed as a script, not when imported.
if __name__ == "__main__":
    main()