Skip to content

Commit 9c62460

Browse files
committed
bug: fixed HF model load and paths
1 parent 2ffc78f commit 9c62460

1 file changed

Lines changed: 26 additions & 6 deletions

File tree

src/pipeline/entity_extraction_pipeline.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Author: Ty Andrews
22
# Date: 2023-06-05
33
"""
4-
Usage: entity_extraction_pipeline.py --article_text_path=<article_text_path> --output_path=<output_path>
4+
Usage: entity_extraction_pipeline.py --article_text_path=<article_text_path> --output_path=<output_path> [--max_sentences=<max_sentences>] [--max_articles=<max_articles>]
55
66
Options:
77
--article_text_path=<article_text_path> The path to the article text data file.
88
--output_path=<output_path> The path to export the extracted entities to.
9+
--max_sentences=<max_sentences> The maximum number of sentences to extract entities from. [default: -1]
10+
--max_articles=<max_articles>            The maximum number of articles to extract entities from. [default: -1]
911
"""
1012

1113
import os
@@ -32,7 +34,7 @@
3234
load_dotenv(find_dotenv())
3335

3436
# get the MODEL_NAME from environment variables
35-
HF_NER_MODEL_NAME = os.getenv("HF_NER_MODEL_NAME", "finding-fossils/metaextractor")
37+
HF_NER_MODEL_PATH = os.getenv("HF_NER_MODEL_PATH", "./models/ner/metaextractor")
3638
SPACY_NER_MODEL_NAME = os.getenv("SPACY_NER_MODEL_NAME", "en_metaextractor_spacy")
3739
USE_NER_MODEL_TYPE = os.getenv("USE_NER_MODEL_TYPE", "huggingface")
3840
MAX_SENTENCES = os.getenv("MAX_SENTENCES", "-1")
@@ -286,7 +288,7 @@ def recreate_original_sentences_with_labels(row):
286288
def extract_entities(
287289
article_text_data: pd.DataFrame,
288290
model_type: str = "huggingface",
289-
model_path: str = "finding-fossils/metaextractor",
291+
model_path: str = "metaextractor",
290292
) -> pd.DataFrame:
291293
"""
292294
Extracts the entities from the article text data.
@@ -562,19 +564,30 @@ def main():
562564
]
563565
)
564566
]
567+
logger.info(
568+
f"Using just a subsample of the data of with {int(MAX_ARTICLES)} articles"
569+
)
565570

566571
# if max_sentences is not -1 then only use the first max_sentences sentences
567572
if MAX_SENTENCES is not None and int(MAX_SENTENCES) != -1:
568-
article_text_data = article_text_data.head(int(MAX_SENTENCES))
573+
# get just sentence id's for each gdd up to max_sentences
574+
article_text_data = article_text_data[
575+
article_text_data["sentid"].isin(
576+
article_text_data["sentid"].unique()[0 : int(MAX_SENTENCES)]
577+
)
578+
]
579+
logger.info(
580+
f"Using just a subsample of the data of with {int(MAX_SENTENCES)} sentences"
581+
)
569582

570583
for article_gdd in article_text_data["gddid"].unique():
571584
logger.info(f"Processing GDD ID: {article_gdd}")
572585

573586
article_text = article_text_data[article_text_data["gddid"] == article_gdd]
574587

575588
if USE_NER_MODEL_TYPE == "huggingface":
576-
logger.info(f"Using HuggingFace model {HF_NER_MODEL_NAME}")
577-
model_path = HF_NER_MODEL_NAME
589+
logger.info(f"Using HuggingFace model {HF_NER_MODEL_PATH}")
590+
model_path = HF_NER_MODEL_PATH
578591
elif USE_NER_MODEL_TYPE == "spacy":
579592
logger.info(f"Using Spacy model {SPACY_NER_MODEL_NAME}")
580593
model_path = SPACY_NER_MODEL_NAME
@@ -611,6 +624,13 @@ def main():
611624
)
612625
continue
613626

627+
# delete the file if it already exists with the article_gdd name
628+
if os.path.exists(os.path.join(opt["--output_path"], f"{article_gdd}.json")):
629+
os.remove(os.path.join(opt["--output_path"], f"{article_gdd}.json"))
630+
logger.warning(
631+
f"Deleted existing file {article_gdd}.json in output directory."
632+
)
633+
614634
export_extracted_entities(
615635
extracted_entities=pprocessed_entities,
616636
output_path=opt["--output_path"],

0 commit comments

Comments
 (0)