Skip to content

Commit cdec34a

Browse files
committed
Updating the relevance prediction, linting and comments.
No major changes to the code, just getting it to work locally and applying style changes.
1 parent f6e6e37 commit cdec34a

1 file changed

Lines changed: 60 additions & 54 deletions

File tree

src/article_relevance/relevance_prediction_parquet.py

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,65 +5,69 @@
55
# Input: list of doi
66
# output: df containing all metadata, predicted relevance, predict_proba
77

8-
"""This script takes a list of DOI as input and output a dataframe containing all metadata, predicted relevance, predict_proba of each article.
8+
"""
9+
This script takes a list of DOI as input and outputs a dataframe containing
10+
all metadata, predicted relevance, and predict_proba of each article.
911
10-
Usage: relevance_prediction_parquet.py --doi_file_path=<doi_path> --model_path=<model_path> --output_path=<output_path> --send_xdd=<send_xdd>
12+
Usage: relevance_prediction_parquet.py --doi_file_path=<doi_path>
13+
--model_path=<model_path> --output_path=<output_path> --send_xdd=<send_xdd>
1114
1215
Options:
1316
--doi_file_path=<doi_file_path> The path to where the list of DOI is.
1417
--model_path=<model_path> The path to where the model object is stored.
1518
--output_path=<output_path> The path to where the output files will be saved.
16-
--send_xdd=<send_xdd> When True, relevant articles will be sent to xDD through API query. Default is False.
19+
--send_xdd=<send_xdd> When True, relevant articles will be sent to xDD
20+
through API query. Default is False.
1721
"""
1822

19-
import pandas as pd
20-
import numpy as np
2123
import os
22-
import requests
2324
import json
2425
import sys
26+
import datetime
27+
import pandas as pd
28+
import numpy as np
29+
import requests
2530
from langdetect import detect
2631
from sentence_transformers import SentenceTransformer
2732
import joblib
2833
from docopt import docopt
2934
import pyarrow as pa
3035
import pyarrow.parquet as pq
31-
import datetime
36+
from logs import get_logger
3237

3338
# Locate src module
3439
current_dir = os.path.dirname(os.path.abspath(__file__))
3540
src_dir = os.path.dirname(current_dir)
3641
sys.path.append(src_dir)
3742

38-
from logs import get_logger
39-
4043
logger = get_logger(__name__) # this gets the object with the current modules name
4144

4245

4346
def crossref_extract(doi_path):
4447
"""Extract metadata from the Crossref API for articles in the DOI CSV file.
4548
Extracted data are returned in a pandas dataframe.
4649
47-
If certain DOI is not found on CrossRef, the DOI will be logged in the prediction_pipeline.log file.
48-
50+
If certain DOI is not found on CrossRef, the DOI will be logged in the
51+
prediction_pipeline.log file.
52+
4953
Args:
5054
doi_path (str): Path to the doi list JSON file.
5155
doi_col (str): Column name of DOI.
52-
56+
5357
Return:
5458
pandas Dataframe containing CrossRef metadata.
5559
"""
5660

57-
logger.info(f'Running crossref_extract function.')
61+
logger.info("Running crossref_extract function.")
5862

5963

60-
with open(doi_path) as json_file:
64+
with open(doi_path, mode = "r", encoding = "UTF-8") as json_file:
6165
data_dictionary = json.load(json_file)
6266

6367
df = pd.DataFrame(data_dictionary['data'])
6468

6569
if df.shape[0] == 0:
66-
logger.warning(f'Last xDD API query did not retrieve any article. Please verify the arguments.')
70+
logger.warning("Last xDD API query did not retrieve any article. Please verify the arguments.")
6771
raise ValueError("No article to process. Script terminated.")
6872

6973
doi_col = 'DOI'
@@ -74,14 +78,16 @@ def crossref_extract(doi_path):
7478
# Initialize
7579
crossref = pd.DataFrame()
7680

77-
logger.info(f'Querying CrossRef API for article metadata.')
81+
logger.info("Querying CrossRef API for article metadata.")
7882

7983
# Loop through all doi, concatenate metadata into dataframe
8084
for doi in input_doi:
8185
cross_ref_url = f"https://api.crossref.org/works/{doi}"
8286

8387
# make a request to the API
84-
cross_ref_response = requests.get(cross_ref_url)
88+
cross_ref_response = requests.get(cross_ref_url, timeout=5000,
89+
headers= {"User-Agent":"""NeotomaArticleRelevanceTracker;
90+
(https://neotomadb.org; mailto:[email protected])"""})
8591

8692
if cross_ref_response.status_code == 200:
8793

@@ -92,12 +98,12 @@ def crossref_extract(doi_path):
9298
ref_df['abstract'] = ''
9399
crossref = pd.concat([crossref, ref_df])
94100

95-
else:
101+
else:
96102
pass
97-
103+
98104
logger.info(f'CrossRef API query completed for {len(input_doi)} articles.')
99105

100-
106+
101107
# Clean up columns and return the resulting pandas data frame
102108
crossref_keep_col = ['valid_for_prediction', 'DOI',
103109
'URL',
@@ -107,12 +113,12 @@ def crossref_extract(doi_path):
107113
'is-referenced-by-count', # times cited
108114
'language',
109115
'published', # datetime
110-
'publisher',
116+
'publisher',
111117
'subject', # keywords of journal
112118
'subtitle', # subtitle are missing sometimes
113119
'title'
114120
]
115-
121+
116122
crossref = crossref.loc[:, crossref_keep_col].reset_index(drop = True)
117123

118124

@@ -135,27 +141,27 @@ def crossref_extract(doi_path):
135141

136142

137143
def en_only_helper(value):
138-
''' Helper function for en_only.
144+
''' Helper function for en_only.
139145
Apply row-wise to impute missing language data.'''
140-
146+
141147
try:
142148
detect_lang = detect(value)
143149
except:
144150
detect_lang = "error"
145151
logger.info("This text throws an error:", value)
146-
152+
147153
return detect_lang
148-
154+
149155

150156
def data_preprocessing(metadata_df):
151157
"""
152158
Clean up title, subtitle, abstract, subject.
153159
Feature engineer for descriptive text column.
154160
Impute language.
155161
The outputted dataframe is ready to be used in model prediction.
156-
162+
157163
Args:
158-
metadata_df (pd DataFrame): Input data frame.
164+
metadata_df (pd DataFrame): Input data frame.
159165
160166
Returns:
161167
pd DataFrame containing all info required for model prediction.
@@ -209,7 +215,7 @@ def data_preprocessing(metadata_df):
209215
metadata_df.loc[cannot_impute_condition, 'valid_for_prediction'] = 0
210216
en_condition = (metadata_df['language'] != 'en')
211217
metadata_df.loc[en_condition, 'valid_for_prediction'] = 0
212-
218+
213219
logger.info("Missing language imputation completed")
214220
logger.info(f"After imputation, there are {metadata_df.loc[en_condition, :].shape[0]} non-English articles in total excluded from the prediction pipeline.")
215221

@@ -220,13 +226,13 @@ def data_preprocessing(metadata_df):
220226
'queryinfo_max_date',
221227
'queryinfo_term',
222228
'queryinfo_n_recent']
223-
229+
224230
metadata_df = metadata_df.loc[:, keep_col]
225231

226232
metadata_df = metadata_df.rename(columns={'title_clean': 'title',
227233
'subtitle_clean': 'subtitle',
228234
'abstract_clean': 'abstract'})
229-
235+
230236
# invalid when required input field is Null
231237
mask = metadata_df[['text_with_abstract', 'subject_clean', 'is-referenced-by-count', 'has_abstract']].isnull().any(axis=1)
232238
metadata_df.loc[mask, 'valid_for_prediction'] = 0
@@ -238,16 +244,15 @@ def data_preprocessing(metadata_df):
238244
logger.info(f'{with_missing_df.shape[0]} articles has missing feature and its relevance cannot be predicted.')
239245
logger.info(f'Data preprocessing completed.')
240246

241-
242247
return metadata_df
243248

244249

245250
def add_embeddings(input_df, text_col, model = 'allenai/specter2'):
246251
"""
247252
Add sentence embeddings to the dataframe using the specified model.
248-
253+
249254
Args:
250-
input_df (pd DataFrame): Input data frame.
255+
input_df (pd DataFrame): Input data frame.
251256
text_col (str): Column with text feature.
252257
model(str): model name on hugging face model hub.
253258
@@ -279,26 +284,26 @@ def add_embeddings(input_df, text_col, model = 'allenai/specter2'):
279284

280285
def relevance_prediction(input_df, model_path, predict_thld = 0.5):
281286
"""
282-
Make prediction on article relevancy.
287+
Make prediction on article relevancy.
283288
Add prediction and predict_proba to the resulting dataframe.
284289
Save resulting dataframe with all information in output_path directory.
285290
Return the resulting dataframe.
286291
287292
Args:
288-
input_df (pd DataFrame): Input data frame.
293+
input_df (pd DataFrame): Input data frame.
289294
model_path (str): Directory to trained model object.
290295
291296
Returns:
292297
pd DataFrame with prediction and predict_proba added.
293298
"""
294-
logger.info(f'Prediction start.')
295-
299+
logger.info("Prediction start.")
300+
296301
try:
297302
# load model
298303
model_object = joblib.load(model_path)
299-
except OSError:
304+
except OSError as exc:
300305
logger.error("Model for article relevance not found.")
301-
raise(FileNotFoundError)
306+
raise FileNotFoundError from exc
302307

303308
# split by valid_for_prediction
304309
valid_df = input_df.query('valid_for_prediction == 1')
@@ -308,6 +313,7 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
308313

309314
# filter out rows with NaN value
310315
feature_col = ['has_abstract', 'subject_clean', 'is-referenced-by-count'] + [str(i) for i in range(0,768)]
316+
logger.info(feature_col)
311317
nan_exists = valid_df.loc[:, feature_col].isnull().any(axis = 1)
312318
df_nan_exist = valid_df.loc[nan_exists, :]
313319
valid_df.loc[nan_exists, 'valid_for_prediction'] = 0
@@ -319,9 +325,9 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
319325

320326
# Filter results, store key information that could possibly be useful downstream
321327
keyinfo_col = (['DOI', 'URL', 'gddid', 'valid_for_prediction',
322-
'prediction', 'predict_proba'] +
323-
feature_col +
324-
['title', 'subtitle', 'abstract', 'journal',
328+
'prediction', 'predict_proba'] +
329+
feature_col +
330+
['title', 'subtitle', 'abstract', 'journal',
325331
'author', 'text_with_abstract', 'language', 'published', 'publisher',
326332
'queryinfo_min_date',
327333
'queryinfo_max_date',
@@ -335,7 +341,7 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
335341
'queryinfo_max_date',
336342
'queryinfo_term',
337343
'queryinfo_n_recent']
338-
344+
339345
keyinfo_df = valid_df.loc[:, keyinfo_col]
340346

341347
# Join it with invalid df to get back to the full dataframe
@@ -352,9 +358,9 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
352358
def xdd_put_request(row):
353359
"""
354360
If the article is predicted to be relevant, query xDD for full text.
355-
361+
356362
Args:
357-
row (a row in pd DataFrame)
363+
row (a row in pd DataFrame)
358364
359365
Returns:
360366
'success' if the query was successful, otherwise 'failed'
@@ -372,7 +378,7 @@ def xdd_put_request(row):
372378
# ========= Mock output ========
373379
status = 200
374380

375-
# =====
381+
# =====
376382
if status == 200:
377383
return "success"
378384
else:
@@ -381,13 +387,13 @@ def xdd_put_request(row):
381387

382388
def prediction_export(input_df, output_path):
383389
"""
384-
Make prediction on article relevancy.
390+
Make prediction on article relevancy.
385391
Add prediction and predict_proba to the resulting dataframe.
386392
Save resulting dataframe with all information in output_path directory.
387393
Return the resulting dataframe.
388394
389395
Args:
390-
input_df (pd DataFrame): Input data frame.
396+
input_df (pd DataFrame): Input data frame.
391397
model_path (str): Directory to trained model object.
392398
393399
Returns:
@@ -398,7 +404,7 @@ def prediction_export(input_df, output_path):
398404
parquet_folder = os.path.join(output_path, 'prediction_parquet')
399405
if not os.path.exists(parquet_folder):
400406
os.makedirs(parquet_folder)
401-
407+
402408
# Generate file name based on run date and batch
403409
now = datetime.datetime.now()
404410
formatted_datetime = now.strftime("%Y-%m-%dT%H-%M-%S")
@@ -424,16 +430,16 @@ def main():
424430
doi_list_file_path = opt["--doi_file_path"]
425431
output_path = opt['--output_path']
426432
send_xdd = opt['--send_xdd']
427-
433+
428434
# # /models directory is a mounted volume, containing the model object
429435
# models = os.listdir("/models")
430436
# models = [f for f in models if f.endswith(".joblib")]
431-
437+
432438
# if models:
433439
# model_path = os.path.join("/models", models[0])
434440
# else:
435441
# model_path = ""
436-
442+
437443
model_path = opt['--model_path']
438444

439445
metadata_df = crossref_extract(doi_list_file_path)
@@ -447,7 +453,7 @@ def main():
447453
if send_xdd =="True":
448454
# run xdd_put_request function, add the xddquery_status column to the parquet
449455
predicted.loc[:, 'xdd_querystatus'] = predicted.apply(xdd_put_request, axis=1)
450-
456+
451457
prediction_export(predicted, output_path)
452458

453459

0 commit comments

Comments
 (0)