Skip to content

Commit 6e5d0ac

Browse files
committed
retrain docker small fix
1 parent 465267c commit 6e5d0ac

3 files changed

Lines changed: 16 additions & 6 deletions

File tree

docker/article-relevance-retrain/Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ COPY models/article-relevance ./models/article-relevance
1818

1919
# Copy the data folder into the container
2020
COPY data/article-relevance ./data/article-relevance
21+
COPY data/data-review-tool/processed ./data/data-review-tool/processed
22+
2123

2224
# Copy the shell script to the container
2325
COPY docker/article-relevance-retrain/run-retrain.sh .

docker/article-relevance-retrain/run-retrain.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
python src/relevance_prediction_model_retrain.py \
3+
python src/article_relevance/relevance_prediction_model_retrain.py \
44
--use_reviewed_data="$USE_REVIEWED_DATA" \
55
--train_data_path="$TRAIN_DATA_PATH" \
66
--model_folder="$MODEL_FOLDER" \

src/article_relevance/relevance_prediction_model_retrain.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ def train_data_load_split(train_raw_csv_path):
7272
Three pandas DataFrames: train_df, valid_df, test_df
7373
7474
Example:
75-
train_data_load_split(train_raw_csv_path = "../../data/article-relevance/processed/metadata_processed_embedded.csv")
75+
train_data_load_split(train_raw_csv_path = "../../data/article-relevance/processed/metadata_embeded_specter2.csv")
7676
'''
7777

7878
# load original training sample
@@ -81,14 +81,15 @@ def train_data_load_split(train_raw_csv_path):
8181
metadata_df['text_with_abstract'].fillna("", inplace=True)
8282
metadata_df['subject_clean'].fillna("", inplace=True)
8383

84-
if metadata_df['has_abstract'].isna().any():
85-
raise ValueError(f"Column 'has_abstract' contains NaN values.")
86-
if metadata_df['is-referenced-by-count'].isna().any():
87-
raise ValueError(f"Column 'is-referenced-by-count' contains NaN values.")
84+
8885
if metadata_df['text_with_abstract'].isna().any():
8986
raise ValueError(f"Column 'text_with_abstract' contains NaN values.")
9087
if metadata_df['target'].isna().any():
9188
raise ValueError(f"Column 'target' contains NaN values.")
89+
if metadata_df['has_abstract'].isna().any():
90+
raise ValueError(f"Column 'has_abstract' contains NaN values.")
91+
if metadata_df['is-referenced-by-count'].isna().any():
92+
raise ValueError(f"Column 'is-referenced-by-count' contains NaN values.")
9293

9394

9495
# Split into train/valid/test sets
@@ -251,6 +252,9 @@ def model_train(train_df, model_dir, model_c = 0.01563028103558011):
251252
now = datetime.datetime.now()
252253
formatted_datetime = now.strftime("%Y-%m-%dT%H-%M-%S")
253254

255+
if not os.path.exists(model_dir):
256+
os.makedirs(model_dir)
257+
254258
model_file_name = os.path.join(model_dir, f"retrained_model_{formatted_datetime}.joblib")
255259
joblib.dump(logreg_model, model_file_name)
256260

@@ -276,6 +280,9 @@ def model_eval(model, valid_df, test_df, report_dir):
276280
Return:
277281
None
278282
'''
283+
284+
if not os.path.exists(report_dir):
285+
os.makedirs(report_dir)
279286

280287
# ======= Only keep feature columns ==========
281288
keep_col = ['target', 'has_abstract', 'subject_clean', 'is-referenced-by-count'] + [str(i) for i in range(0,768)]
@@ -350,6 +357,7 @@ def model_eval(model, valid_df, test_df, report_dir):
350357
logger.info(f'Evaluation - test precision = {round(TP / (TP + FP), 3)}')
351358

352359
# Convert the metrics to a JSON file and export
360+
353361
report_file_path = os.path.join(report_dir, f"retrained_model_{formatted_datetime}_metrics.json")
354362

355363
with open(report_file_path, 'w') as json_file:

0 commit comments

Comments
 (0)