Skip to content

Commit b46cb3d

Browse files
committed
Add all of @ASvyatkovskiy transformer work from late 2019
1 parent 9935ff7 commit b46cb3d

File tree

7 files changed

+1023
-0
lines changed

7 files changed

+1023
-0
lines changed

examples/conf.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ model:
5858
loss_scale_factor: 1.0
5959
use_batch_norm: false
6060
torch: False
61+
transformer: False # requires torch: True
6162
shallow: False
6263
shallow_model:
6364
num_samples: 1000000 # 1000000 # the number of samples to use for training

examples/slurm.cmd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,13 @@ rm /tigress/$USER/normalization/*
2727

2828
export OMPI_MCA_btl="tcp,self,vader"
2929
srun python mpi_learn.py
30+
31+
# single model replica PyTorch implementation of Transformer
32+
# (set one rank, one core, ...)
33+
#
34+
# conda activate Py3
35+
# module load cudnn/cuda-10.0/7.5.0
36+
# module load cudatoolkit/10.0
37+
# module load openmpi/gcc/3.1.3/64
38+
# export LD_LIBRARY_PATH=/usr/local/cuda-10.0/extras/CUPTI/lib64:$LD_LIBRARY_PATH
39+
# srun python transformer_learn.py

examples/transformer_learn.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from plasma.models.loader import Loader
2+
from plasma.preprocessor.preprocess import guarantee_preprocessed
3+
from plasma.transformer.runner import train
4+
from plasma.models.torch_runner import make_predictions_and_evaluate_gpu
5+
from plasma.conf import conf
6+
7+
from pprint import pprint
8+
import numpy as np
9+
import datetime
10+
import logging
11+
import random
12+
import sys
13+
import os
14+
15+
import matplotlib
16+
matplotlib.use('Agg')
17+
18+
pprint(conf)
19+
20+
if conf['data']['normalizer'] == 'minmax':
21+
from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer
22+
elif conf['data']['normalizer'] == 'meanvar':
23+
from plasma.preprocessor.normalize import MeanVarNormalizer as Normalizer
24+
elif conf['data']['normalizer'] == 'var':
25+
# performs !much better than minmaxnormalizer
26+
from plasma.preprocessor.normalize import VarNormalizer as Normalizer
27+
elif conf['data']['normalizer'] == 'averagevar':
28+
# performs !much better than minmaxnormalizer
29+
from plasma.preprocessor.normalize import (
30+
AveragingVarNormalizer as Normalizer
31+
)
32+
else:
33+
print('unkown normalizer. exiting')
34+
exit(1)
35+
36+
if __name__ == '__main__':
37+
logging.basicConfig(
38+
level=logging.INFO,
39+
format="%(asctime)-15s %(name)-5s %(levelname)-8s %(message)s",
40+
)
41+
LOGGER = logging.getLogger("transformer_learn")
42+
43+
shot_list_dir = conf['paths']['shot_list_dir']
44+
shot_files = conf['paths']['shot_files']
45+
shot_files_test = conf['paths']['shot_files_test']
46+
train_frac = conf['training']['train_frac']
47+
stateful = conf['model']['stateful']
48+
49+
# FIXME change seed setting
50+
np.random.seed(0)
51+
random.seed(0)
52+
53+
only_predict = len(sys.argv) > 1
54+
custom_path = None
55+
if only_predict:
56+
custom_path = sys.argv[1]
57+
print("predicting using path {}".format(custom_path))
58+
59+
#####################################################
60+
# PREPROCESSING #
61+
#####################################################
62+
# TODO(KGF): check tuple unpack
63+
(shot_list_train, shot_list_validate,
64+
shot_list_test) = guarantee_preprocessed(conf)
65+
66+
#####################################################
67+
# NORMALIZATION #
68+
#####################################################
69+
70+
print("normalization", end='')
71+
nn = Normalizer(conf)
72+
nn.train()
73+
loader = Loader(conf, nn)
74+
print("...done")
75+
print('Training on {} shots, testing on {} shots'.format(
76+
len(shot_list_train), len(shot_list_test)))
77+
78+
79+
#####################################################
80+
# TRAINING #
81+
#####################################################
82+
train(conf, shot_list_train.random_sublist(512),
83+
shot_list_validate.random_sublist(256), loader)
84+
#if not only_predict:
85+
# p = old_mp.Process(target=train,
86+
# args=(conf, shot_list_train,
87+
# shot_list_validate, loader)
88+
# )
89+
# p.start()
90+
# p.join()
91+
92+
#####################################################
93+
# PREDICTING #
94+
#####################################################
95+
loader.set_inference_mode(True)
96+
97+
# load last model for testing
98+
print('saving results')
99+
y_prime = []
100+
y_prime_test = []
101+
y_prime_train = []
102+
103+
y_gold = []
104+
y_gold_test = []
105+
y_gold_train = []
106+
107+
disruptive = []
108+
disruptive_train = []
109+
disruptive_test = []
110+
111+
# y_prime_train, y_gold_train, disruptive_train =
112+
# make_predictions(conf, shot_list_train, loader)
113+
# y_prime_test, y_gold_test, disruptive_test =
114+
# make_predictions(conf, shot_list_test, loader)
115+
116+
# TODO(KGF): check tuple unpack
117+
(y_prime_train, y_gold_train, disruptive_train, roc_train,
118+
loss_train) = make_predictions_and_evaluate_gpu(
119+
conf, shot_list_train, loader, custom_path)
120+
(y_prime_test, y_gold_test, disruptive_test, roc_test,
121+
loss_test) = make_predictions_and_evaluate_gpu(
122+
conf, shot_list_test, loader, custom_path)
123+
print('=========Summary========')
124+
print('Train Loss: {:.3e}'.format(loss_train))
125+
print('Train ROC: {:.4f}'.format(roc_train))
126+
print('Test Loss: {:.3e}'.format(loss_test))
127+
print('Test ROC: {:.4f}'.format(roc_test))
128+
129+
130+
disruptive_train = np.array(disruptive_train)
131+
disruptive_test = np.array(disruptive_test)
132+
133+
y_gold = y_gold_train + y_gold_test
134+
y_prime = y_prime_train + y_prime_test
135+
disruptive = np.concatenate((disruptive_train, disruptive_test))
136+
137+
shot_list_validate.make_light()
138+
shot_list_test.make_light()
139+
shot_list_train.make_light()
140+
141+
save_str = 'results_' + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
142+
result_base_path = conf['paths']['results_prepath']
143+
if not os.path.exists(result_base_path):
144+
os.makedirs(result_base_path)
145+
np.savez(result_base_path+save_str, y_gold=y_gold, y_gold_train=y_gold_train,
146+
y_gold_test=y_gold_test, y_prime=y_prime, y_prime_train=y_prime_train,
147+
y_prime_test=y_prime_test, disruptive=disruptive,
148+
disruptive_train=disruptive_train, disruptive_test=disruptive_test,
149+
shot_list_validate=shot_list_validate,
150+
shot_list_train=shot_list_train, shot_list_test=shot_list_test,
151+
conf=conf)
152+
153+
print('finished.')

0 commit comments

Comments
 (0)