Skip to content

Commit 80a74b8

Browse files
committed
add main
1 parent a30d878 commit 80a74b8

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

src/article_relevance/relevance_prediction_model_retrain.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,17 @@ def model_eval(model, valid_df, test_df, report_dir):
349349
'test_FP': FP,
350350
}
351351

352+
logger.info(f'Evaluation - test recall = {TP / (TP + FN)}')
353+
logger.info(f'Evaluation - test precision = {TP / (TP + FP)}')
354+
352355
# convert to Json file and export
353356
report_file_path = os.path.join(report_dir, f"retrained_model_{formatted_datetime}_metrics.json")
354357

355358
with open(report_file_path, 'w') as json_file:
356359
json.dump(results, json_file)
357360

361+
logger.info(f'Evaluation - Completed. Results saved in specified folder.')
362+
358363

359364
def main():
360365

@@ -381,6 +386,10 @@ def main():
381386
train_df_merged, valid_df_merged, test_df_merged = retrain_data_merge(train_df_old, train_df_new, valid_df_old, valid_df_new, test_df_old, test_df_new)
382387

383388
retrained_model = model_train(train_df_merged, model_folder, model_c = 0.01563028103558011)
384-
389+
385390
model_eval(retrained_model, valid_df_merged, test_df_merged, result_dir)
386391

392+
393+
394+
if __name__ == "__main__":
395+
main()

0 commit comments

Comments
 (0)