@@ -72,7 +72,7 @@ def train_data_load_split(train_raw_csv_path):
7272 Three pandas Data frames: train_df, valid_df, test_df
7373
7474 Example:
75- train_data_load_split(train_raw_csv_path = "../../data/article-relevance/processed/metadata_processed_embedded .csv")
75+ train_data_load_split(train_raw_csv_path = "../../data/article-relevance/processed/metadata_embeded_specter2 .csv")
7676 '''
7777
7878 # load original training sample
@@ -81,14 +81,15 @@ def train_data_load_split(train_raw_csv_path):
8181 metadata_df ['text_with_abstract' ].fillna ("" , inplace = True )
8282 metadata_df ['subject_clean' ].fillna ("" , inplace = True )
8383
84- if metadata_df ['has_abstract' ].isna ().any ():
85- raise ValueError (f"Column 'has_abstract' contains NaN values." )
86- if metadata_df ['is-referenced-by-count' ].isna ().any ():
87- raise ValueError (f"Column 'is-referenced-by-count' contains NaN values." )
84+
8885 if metadata_df ['text_with_abstract' ].isna ().any ():
8986 raise ValueError (f"Column 'text_with_abstract' contains NaN values." )
9087 if metadata_df ['target' ].isna ().any ():
9188 raise ValueError (f"Column 'target' contains NaN values." )
89+ if metadata_df ['has_abstract' ].isna ().any ():
90+ raise ValueError (f"Column 'has_abstract' contains NaN values." )
91+ if metadata_df ['is-referenced-by-count' ].isna ().any ():
92+ raise ValueError (f"Column 'is-referenced-by-count' contains NaN values." )
9293
9394
9495 # Split into train/valid/test sets
@@ -251,6 +252,9 @@ def model_train(train_df, model_dir, model_c = 0.01563028103558011):
251252 now = datetime .datetime .now ()
252253 formatted_datetime = now .strftime ("%Y-%m-%dT%H-%M-%S" )
253254
255+ if not os .path .exists (model_dir ):
256+ os .makedirs (model_dir )
257+
254258 model_file_name = os .path .join (model_dir , f"retrained_model_{ formatted_datetime } .joblib" )
255259 joblib .dump (logreg_model , model_file_name )
256260
@@ -276,6 +280,9 @@ def model_eval(model, valid_df, test_df, report_dir):
276280 Return:
277281 None
278282 '''
283+
284+ if not os .path .exists (report_dir ):
285+ os .makedirs (report_dir )
279286
280287 # ======= Only keep feature columns ==========
281288 keep_col = ['target' , 'has_abstract' , 'subject_clean' , 'is-referenced-by-count' ] + [str (i ) for i in range (0 ,768 )]
@@ -350,6 +357,7 @@ def model_eval(model, valid_df, test_df, report_dir):
350357 logger .info (f'Evaluation - test precision = { round (TP / (TP + FP ), 3 )} ' )
351358
352359 # convert to Json file and export
360+
353361 report_file_path = os .path .join (report_dir , f"retrained_model_{ formatted_datetime } _metrics.json" )
354362
355363 with open (report_file_path , 'w' ) as json_file :
0 commit comments