Skip to content

Commit 8aa5c1a

Browse files
author
Julian Kates-Harbeck
committed
Implemented a more general hash function so that deep copies of the conf preserve the same hash value.
1 parent 34f66b6 commit 8aa5c1a

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

plasma/models/builder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os,sys
2020
import numpy as np
2121
from copy import deepcopy
22-
from plasma.utils.downloading import makedirs_process_safe
22+
from plasma.utils.downloading import makedirs_process_safe,general_object_hash
2323

2424
import hashlib
2525

@@ -31,18 +31,22 @@ def on_batch_end(self, batch, logs=None):
3131
self.losses.append(logs.get('loss'))
3232

3333

34+
3435
class ModelBuilder(object):
3536
def __init__(self,conf):
3637
self.conf = conf
3738

3839
def get_unique_id(self):
    """Return an integer id uniquely identifying this conf, used to tag
    saved model checkpoint files.

    Fields that may legitimately change between training and later
    analysis runs must not alter the id, so they are reset to fixed
    sentinel values on a copy of the conf before hashing:
      - training.num_epochs -> 0   (resuming with more epochs keeps the id)
      - data.T_min_warn     -> 30  (warning-time threshold can be tuned
                                    after training without invalidating
                                    saved weights)

    Returns:
        int: deterministic hash of the normalized conf.
    """
    this_conf = deepcopy(self.conf)
    # Normalize mutable-between-runs fields so the hash stays stable.
    this_conf['training']['num_epochs'] = 0
    this_conf['data']['T_min_warn'] = 30
    return general_object_hash(this_conf)
4548

49+
4650
def get_0D_1D_indices(self):
4751
#make sure all 1D indices are contiguous in the end!
4852
use_signals = self.conf['paths']['use_signals']
@@ -270,6 +274,8 @@ def get_all_saved_files(self):
270274
self.ensure_save_directory()
271275
unique_id = self.get_unique_id()
272276
filenames = os.listdir(self.conf['paths']['model_save_path'])
277+
print("All saved files with id {} and path {}".format(unique_id,self.conf['paths']['model_save_path']))
278+
print(filenames)
273279
epochs = []
274280
for file in filenames:
275281
curr_id,epoch = self.extract_id_and_epoch_from_filename(file)

plasma/models/mpi_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ def mpi_train(conf,shot_list_train,shot_list_validate,loader, callbacks_list=Non
731731
mpi_model.batch_iterator_func.__exit__()
732732
mpi_model.num_so_far_accum = mpi_model.num_so_far_indiv
733733
mpi_model.set_batch_iterator_func()
734+
734735
if 'monitor_test' in conf['callbacks'].keys() and conf['callbacks']['monitor_test']:
735736
conf_curr = deepcopy(conf)
736737
T_min_warn_orig = conf['data']['T_min_warn']

plasma/utils/downloading.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import os
2727
import errno
2828

29+
import dill,hashlib
30+
2931
# import gadata
3032

3133
# from plasma.primitives.shots import ShotList
@@ -34,6 +36,26 @@
3436

3537
#print("Importing numpy version"+np.__version__)
3638

39+
def general_object_hash(o):
    """Make a deterministic hash of a dictionary, list, tuple or set, nested
    to any level, whose leaves are dill-serializable objects.

    Dict entries are hashed order-independently (items are sorted by key),
    so two deep copies of the same conf yield the same hash. Containers
    (set/tuple/list) hash element-wise to a tuple of leaf hashes; dicts and
    leaves hash to an int. Relies on dill for leaf serialization.

    NOTE(review): fixes two defects in the prior version: the recursive
    calls referenced an undefined name `make_hash` (NameError on any
    container input), and wrapping the sorted items in `frozenset` re-ordered
    them by element hash — which is randomized per process for str keys —
    making the hash unstable across runs.
    """
    if isinstance(o, (set, tuple, list)):
        # Hash containers element-wise; nested structure is normalized
        # into a tuple of per-element hashes.
        return tuple(general_object_hash(e) for e in o)
    elif not isinstance(o, dict):
        # Leaf object: hash its serialized form directly.
        return myhash(o)
    # Dict: replace every value by its hash, then sort items by key so
    # insertion order does not affect the final hash. Building a new dict
    # (rather than deep-copying) avoids mutating the caller's object.
    hashed_items = {k: general_object_hash(v) for k, v in o.items()}
    return myhash(tuple(sorted(hashed_items.items())))
55+
56+
def myhash(x):
    """Return a deterministic integer hash of x: the MD5 digest of its
    dill serialization, interpreted as a base-16 integer."""
    serialized = dill.dumps(x).decode('unicode_escape')
    digest = hashlib.md5(serialized.encode('utf-8')).hexdigest()
    return int(digest, 16)
58+
3759

3860
def get_missing_value_array():
3961
return np.array([-1.0])

0 commit comments

Comments
 (0)