Skip to content

Commit b580584

Browse files
committed
json format fix
1 parent 80a74b8 commit b580584

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

src/article_relevance/relevance_prediction_model_retrain.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@
5252
import matplotlib.pyplot as plt
5353
from sklearn.metrics import confusion_matrix
5454

55-
5655
# Locate src module
57-
current_dir = os.path.dirname(os.path.abspath('__file__'))
56+
current_dir = os.path.dirname(os.path.abspath(__file__))
5857
src_dir = os.path.dirname(current_dir)
5958
sys.path.append(src_dir)
6059

@@ -150,7 +149,6 @@ def retrain_data_load_split(reviewed_parquet_folder_path):
150149
valid_df, test_df = train_test_split(val_test_df, test_size=0.5, random_state=123)
151150

152151
logger.info(f'Data Loading - Reviewed new sample has {train_df.shape[0]}/{valid_df.shape[0]}/{test_df.shape[0]} in train/valid/test splits.')
153-
logger.info(f'Data Loading - Each new sample has {train_df.shape[1]} features.')
154152

155153
return train_df, valid_df, test_df
156154

@@ -182,7 +180,6 @@ def retrain_data_merge(old_train, new_train, old_valid, new_valid,old_test, new_
182180
test_df = pd.concat([old_test, new_test], ignore_index=True)
183181

184182
logger.info(f'Data Loading - Final training sample has {train_df.shape[0]}/{valid_df.shape[0]}/{test_df.shape[0]} in train/valid/test splits.')
185-
logger.info(f'Data Loading - Each merged sample has {train_df.shape[1]} features.')
186183

187184
return train_df, valid_df, test_df
188185

@@ -331,26 +328,26 @@ def model_eval(model, valid_df, test_df, report_dir):
331328
results[f'thld_{thld}'] = {'valid_recall' : recall,
332329
'valid_precision': precision,
333330
'valid_f1': f1_score,
334-
'valid_TN': TN,
335-
'valid_FN': FN,
336-
'valid_TP': TP,
337-
'valid_FP': FP,
331+
'valid_TN': int(TN),
332+
'valid_FN': int(FN),
333+
'valid_TP': int(TP),
334+
'valid_FP': int(FP),
338335
}
339336

340337
# ======= Test set performance, assuming a 0.5 threshold
341338
predictions = model.predict(X_test)
342339
TN, FP, FN, TP = confusion_matrix(y_test, predictions).ravel()
343-
results['test_performance_0.5'] = {'test_recall' : TP / (TP + FN),
344-
'test_precision': TP / (TP + FP),
345-
'test_f1': (2 * precision * recall) / (precision + recall),
346-
'test_TN': TN,
347-
'test_FN': FN,
348-
'test_TP': TP,
349-
'test_FP': FP,
340+
results['test_performance_0.5'] = {'test_recall' : round(TP / (TP + FN),3),
341+
'test_precision': round(TP / (TP + FP),3),
342+
'test_f1': round((2 * precision * recall) / (precision + recall), 3),
343+
'test_TN': int(TN),
344+
'test_FN': int(FN),
345+
'test_TP': int(TP),
346+
'test_FP': int(FP)
350347
}
351348

352-
logger.info(f'Evaluation - test recall = {TP / (TP + FN)}')
353-
logger.info(f'Evaluation - test precision = {TP / (TP + FP)}')
349+
logger.info(f'Evaluation - test recall = {round(TP / (TP + FN), 3)}')
350+
logger.info(f'Evaluation - test precision = {round(TP / (TP + FP), 3)}')
354351

355352
# convert to Json file and export
356353
report_file_path = os.path.join(report_dir, f"retrained_model_{formatted_datetime}_metrics.json")

0 commit comments

Comments
 (0)