55# Input: a list of DOI
66# Output: df containing all metadata, predicted relevance, predict_proba
77
8- """This script takes a list of DOI as input and output a dataframe containing all metadata, predicted relevance, predict_proba of each article.
8+ """
9+ This script takes a list of DOI as input and outputs a dataframe containing
10+ all metadata, the predicted relevance, and the predict_proba of each article.
911
10- Usage: relevance_prediction_parquet.py --doi_file_path=<doi_path> --model_path=<model_path> --output_path=<output_path> --send_xdd=<send_xdd>
12+ Usage: relevance_prediction_parquet.py --doi_file_path=<doi_path>
13+ --model_path=<model_path> --output_path=<output_path> --send_xdd=<send_xdd>
1114
1215Options:
1316 --doi_file_path=<doi_file_path> The path to where the list of DOI is.
1417 --model_path=<model_path> The path to where the model object is stored.
1518 --output_path=<output_path> The path to where the output files will be saved.
16- --send_xdd=<send_xdd> When True, relevant articles will be sent to xDD through API query. Default is False.
19+ --send_xdd=<send_xdd> When True, relevant articles will be sent to xDD
20+ through API query. Default is False.
1721"""
1822
19- import pandas as pd
20- import numpy as np
2123import os
22- import requests
2324import json
2425import sys
26+ import datetime
27+ import pandas as pd
28+ import numpy as np
29+ import requests
2530from langdetect import detect
2631from sentence_transformers import SentenceTransformer
2732import joblib
2833from docopt import docopt
2934import pyarrow as pa
3035import pyarrow .parquet as pq
31- import datetime
36+ from logs import get_logger
3237
3338# Locate src module
3439current_dir = os .path .dirname (os .path .abspath (__file__ ))
3540src_dir = os .path .dirname (current_dir )
3641sys .path .append (src_dir )
3742
38- from logs import get_logger
39-
4043logger = get_logger (__name__ ) # this gets the object with the current modules name
4144
4245
4346def crossref_extract (doi_path ):
4447 """Extract metadata from the Crossref API for articles in the DOI csv file.
4548 Extracted data are returned in a pandas dataframe.
4649
47- If certain DOI is not found on CrossRef, the DOI will be logged in the prediction_pipeline.log file.
48-
50+ If certain DOI is not found on CrossRef, the DOI will be logged in the
51+ prediction_pipeline.log file.
52+
4953 Args:
5054 doi_path (str): Path to the doi list JSON file.
5155 doi_col (str): Column name of DOI.
52-
56+
5357 Return:
5458 pandas Dataframe containing CrossRef metadata.
5559 """
5660
57- logger .info (f' Running crossref_extract function.' )
61+ logger .info (" Running crossref_extract function." )
5862
5963
60- with open (doi_path ) as json_file :
64+ with open (doi_path , mode = "r" , encoding = "UTF-8" ) as json_file :
6165 data_dictionary = json .load (json_file )
6266
6367 df = pd .DataFrame (data_dictionary ['data' ])
6468
6569 if df .shape [0 ] == 0 :
66- logger .warning (f' Last xDD API query did not retrieve any article. Please verify the arguments.' )
70+ logger .warning (" Last xDD API query did not retrieve any article. Please verify the arguments." )
6771 raise ValueError ("No article to process. Script terminated." )
6872
6973 doi_col = 'DOI'
@@ -74,14 +78,16 @@ def crossref_extract(doi_path):
7478 # Initialize
7579 crossref = pd .DataFrame ()
7680
77- logger .info (f' Querying CrossRef API for article metadata.' )
81+ logger .info (" Querying CrossRef API for article metadata." )
7882
7983 # Loop through all doi, concatenate metadata into dataframe
8084 for doi in input_doi :
8185 cross_ref_url = f"https://api.crossref.org/works/{ doi } "
8286
8387 # make a request to the API
84- cross_ref_response = requests .get (cross_ref_url )
88+ cross_ref_response = requests .get (cross_ref_url , timeout = 5000 ,
89+ headers = {"User-Agent" :"""NeotomaArticleRelevanceTracker;
90+ (https://neotomadb.org; mailto:[email protected] )""" })
8591
8692 if cross_ref_response .status_code == 200 :
8793
@@ -92,12 +98,12 @@ def crossref_extract(doi_path):
9298 ref_df ['abstract' ] = ''
9399 crossref = pd .concat ([crossref , ref_df ])
94100
95- else :
101+ else :
96102 pass
97-
103+
98104 logger .info (f'CrossRef API query completed for { len (input_doi )} articles.' )
99105
100-
106+
101107 # Clean up columns and return the resulting pandas data frame
102108 crossref_keep_col = ['valid_for_prediction' , 'DOI' ,
103109 'URL' ,
@@ -107,12 +113,12 @@ def crossref_extract(doi_path):
107113 'is-referenced-by-count' , # times cited
108114 'language' ,
109115 'published' , # datetime
110- 'publisher' ,
116+ 'publisher' ,
111117 'subject' , # keywords of journal
112118 'subtitle' , # subtitle are missing sometimes
113119 'title'
114120 ]
115-
121+
116122 crossref = crossref .loc [:, crossref_keep_col ].reset_index (drop = True )
117123
118124
@@ -135,27 +141,27 @@ def crossref_extract(doi_path):
135141
136142
def en_only_helper(value):
    """Detect the language of a single text value.

    Helper applied row-wise (via pandas ``apply``) to impute missing
    language metadata before filtering to English-only articles.

    Args:
        value (str): Text (e.g. title/abstract) whose language is detected.

    Returns:
        str: Language code reported by ``langdetect.detect`` (e.g. ``'en'``),
        or the sentinel string ``'error'`` when detection fails.
    """
    try:
        detect_lang = detect(value)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed. langdetect raises LangDetectException on
        # empty or undetectable text — presumably the common failure here;
        # TODO(review): confirm and narrow further if so.
        detect_lang = "error"
        # Bug fix: the original `logger.info("This text throws an error:", value)`
        # passed an extra positional argument with no `%s` placeholder, which
        # makes the logging module raise a formatting error when this fires.
        logger.info("This text throws an error: %s", value)

    return detect_lang
148-
154+
149155
150156def data_preprocessing (metadata_df ):
151157 """
152158 Clean up title, subtitle, abstract, subject.
153159 Feature engineer for descriptive text column.
154160 Impute language.
155161 The outputted dataframe is ready to be used in model prediction.
156-
162+
157163 Args:
158- metadata_df (pd DataFrame): Input data frame.
164+ metadata_df (pd DataFrame): Input data frame.
159165
160166 Returns:
161167 pd DataFrame containing all info required for model prediction.
@@ -209,7 +215,7 @@ def data_preprocessing(metadata_df):
209215 metadata_df .loc [cannot_impute_condition , 'valid_for_prediction' ] = 0
210216 en_condition = (metadata_df ['language' ] != 'en' )
211217 metadata_df .loc [en_condition , 'valid_for_prediction' ] = 0
212-
218+
213219 logger .info ("Missing language imputation completed" )
214220 logger .info (f"After imputation, there are { metadata_df .loc [en_condition , :].shape [0 ]} non-English articles in total excluded from the prediction pipeline." )
215221
@@ -220,13 +226,13 @@ def data_preprocessing(metadata_df):
220226 'queryinfo_max_date' ,
221227 'queryinfo_term' ,
222228 'queryinfo_n_recent' ]
223-
229+
224230 metadata_df = metadata_df .loc [:, keep_col ]
225231
226232 metadata_df = metadata_df .rename (columns = {'title_clean' : 'title' ,
227233 'subtitle_clean' : 'subtitle' ,
228234 'abstract_clean' : 'abstract' })
229-
235+
230236 # invalid when required input field is Null
231237 mask = metadata_df [['text_with_abstract' , 'subject_clean' , 'is-referenced-by-count' , 'has_abstract' ]].isnull ().any (axis = 1 )
232238 metadata_df .loc [mask , 'valid_for_prediction' ] = 0
@@ -238,16 +244,15 @@ def data_preprocessing(metadata_df):
238244 logger .info (f'{ with_missing_df .shape [0 ]} articles has missing feature and its relevance cannot be predicted.' )
239245 logger .info (f'Data preprocessing completed.' )
240246
241-
242247 return metadata_df
243248
244249
245250def add_embeddings (input_df , text_col , model = 'allenai/specter2' ):
246251 """
247252 Add sentence embeddings to the dataframe using the specified model.
248-
253+
249254 Args:
250- input_df (pd DataFrame): Input data frame.
255+ input_df (pd DataFrame): Input data frame.
251256 text_col (str): Column with text feature.
252257 model(str): model name on hugging face model hub.
253258
@@ -279,26 +284,26 @@ def add_embeddings(input_df, text_col, model = 'allenai/specter2'):
279284
280285def relevance_prediction (input_df , model_path , predict_thld = 0.5 ):
281286 """
282- Make prediction on article relevancy.
287+ Make prediction on article relevancy.
283288 Add prediction and predict_proba to the resulting dataframe.
284289 Save resulting dataframe with all information in output_path directory.
285290 Return the resulting dataframe.
286291
287292 Args:
288- input_df (pd DataFrame): Input data frame.
293+ input_df (pd DataFrame): Input data frame.
289294 model_path (str): Directory to trained model object.
290295
291296 Returns:
292297 pd DataFrame with prediction and predict_proba added.
293298 """
294- logger .info (f' Prediction start.' )
295-
299+ logger .info (" Prediction start." )
300+
296301 try :
297302 # load model
298303 model_object = joblib .load (model_path )
299- except OSError :
304+ except OSError as exc :
300305 logger .error ("Model for article relevance not found." )
301- raise ( FileNotFoundError )
306+ raise FileNotFoundError from exc
302307
303308 # split by valid_for_prediction
304309 valid_df = input_df .query ('valid_for_prediction == 1' )
@@ -308,6 +313,7 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
308313
309314 # filter out rows with NaN value
310315 feature_col = ['has_abstract' , 'subject_clean' , 'is-referenced-by-count' ] + [str (i ) for i in range (0 ,768 )]
316+ logger .info (feature_col )
311317 nan_exists = valid_df .loc [:, feature_col ].isnull ().any (axis = 1 )
312318 df_nan_exist = valid_df .loc [nan_exists , :]
313319 valid_df .loc [nan_exists , 'valid_for_prediction' ] = 0
@@ -319,9 +325,9 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
319325
320326 # Filter results, store key information that could possibly be useful downstream
321327 keyinfo_col = (['DOI' , 'URL' , 'gddid' , 'valid_for_prediction' ,
322- 'prediction' , 'predict_proba' ] +
323- feature_col +
324- ['title' , 'subtitle' , 'abstract' , 'journal' ,
328+ 'prediction' , 'predict_proba' ] +
329+ feature_col +
330+ ['title' , 'subtitle' , 'abstract' , 'journal' ,
325331 'author' , 'text_with_abstract' , 'language' , 'published' , 'publisher' ,
326332 'queryinfo_min_date' ,
327333 'queryinfo_max_date' ,
@@ -335,7 +341,7 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
335341 'queryinfo_max_date' ,
336342 'queryinfo_term' ,
337343 'queryinfo_n_recent' ]
338-
344+
339345 keyinfo_df = valid_df .loc [:, keyinfo_col ]
340346
341347 # Join it with invalid df to get back to the full dataframe
@@ -352,9 +358,9 @@ def relevance_prediction(input_df, model_path, predict_thld = 0.5):
352358def xdd_put_request (row ):
353359 """
354360 If the article is predicted to be relevant, query xDD for full text.
355-
361+
356362 Args:
357- row (a row in pd DataFrame)
363+ row (a row in pd DataFrame)
358364
359365 Returns:
360366 'success' if the query was successful, otherwise 'failed'
@@ -372,7 +378,7 @@ def xdd_put_request(row):
372378 # ========= Mock output ========
373379 status = 200
374380
375- # =====
381+ # =====
376382 if status == 200 :
377383 return "success"
378384 else :
@@ -381,13 +387,13 @@ def xdd_put_request(row):
381387
382388def prediction_export (input_df , output_path ):
383389 """
384- Make prediction on article relevancy.
390+ Make prediction on article relevancy.
385391 Add prediction and predict_proba to the resulting dataframe.
386392 Save resulting dataframe with all information in output_path directory.
387393 Return the resulting dataframe.
388394
389395 Args:
390- input_df (pd DataFrame): Input data frame.
396+ input_df (pd DataFrame): Input data frame.
391397 model_path (str): Directory to trained model object.
392398
393399 Returns:
@@ -398,7 +404,7 @@ def prediction_export(input_df, output_path):
398404 parquet_folder = os .path .join (output_path , 'prediction_parquet' )
399405 if not os .path .exists (parquet_folder ):
400406 os .makedirs (parquet_folder )
401-
407+
402408 # Generate file name based on run date and batch
403409 now = datetime .datetime .now ()
404410 formatted_datetime = now .strftime ("%Y-%m-%dT%H-%M-%S" )
@@ -424,16 +430,16 @@ def main():
424430 doi_list_file_path = opt ["--doi_file_path" ]
425431 output_path = opt ['--output_path' ]
426432 send_xdd = opt ['--send_xdd' ]
427-
433+
428434 # # /models directory is a mounted volume, containing the model object
429435 # models = os.listdir("/models")
430436 # models = [f for f in models if f.endswith(".joblib")]
431-
437+
432438 # if models:
433439 # model_path = os.path.join("/models", models[0])
434440 # else:
435441 # model_path = ""
436-
442+
437443 model_path = opt ['--model_path' ]
438444
439445 metadata_df = crossref_extract (doi_list_file_path )
@@ -447,7 +453,7 @@ def main():
447453 if send_xdd == "True" :
448454 # run xdd_put_request function, add the xddquery_status column to the parquet
449455 predicted .loc [:, 'xdd_querystatus' ] = predicted .apply (xdd_put_request , axis = 1 )
450-
456+
451457 prediction_export (predicted , output_path )
452458
453459
0 commit comments