Skip to content

Commit a6833c6

Browse files
authored
Merge pull request #99 from NeotomaDB/dev
Update entity name
2 parents 9e42bd2 + 19f9caf commit a6833c6

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/entity_extraction/preprocessing/huggingface_preprocess.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -74,18 +74,22 @@ def convert_labelled_data_to_hf_format(
7474
labelled_chunks = []
7575

7676
for file in os.listdir(data_folder):
77-
# if file doesn't end with txt skip it
78-
if not file.endswith(".txt"):
79-
continue
80-
81-
with open(os.path.join(data_folder, file), "r") as f:
82-
task = json.load(f)
83-
8477
try:
85-
raw_text = task["task"]["data"]["text"]
86-
annotation_result = task["result"]
87-
gdd_id = task["task"]["data"]["gdd_id"]
88-
78+
if file.endswith(".txt"):
79+
with open(os.path.join(data_folder, file), "r") as f:
80+
task = json.load(f)
81+
annotation_result = task["result"]
82+
gdd_id = task["task"]["data"]["gdd_id"]
83+
raw_text = task["task"]["data"]["text"]
84+
elif file.endswith(".json"):
85+
with open(os.path.join(data_folder, file), "r") as f:
86+
task = json.load(f)
87+
annotation_result = task["result"]
88+
gdd_id = task["data"]["gdd_id"]
89+
raw_text = task["data"]["text"]
90+
else:
91+
continue
92+
8993
labelled_entities = [
9094
annotation["value"] for annotation in annotation_result
9195
]

src/entity_extraction/preprocessing/labelling_data_split.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -327,13 +327,17 @@ def extract_parquet_file(labelled_file_path: str):
327327

328328
for ent_type in corrected_entities.keys():
329329
for entity in corrected_entities[ent_type].keys():
330+
if corrected_entities[ent_type][entity]['corrected_name']:
331+
entity_text = corrected_entities[ent_type][entity]['corrected_name']
332+
else:
333+
entity_text = entity
330334
for sentence in corrected_entities[ent_type][entity]['sentence']:
331335
if (sentence['char_index']['start'] != -1 and
332336
sentence['char_index']['end'] != -1):
333337
all_sentences[sentence['sentid']] = sentence['text']
334338
output_files[sentence['sentid']].append({
335339
"value": {
336-
"text": corrected_entities[ent_type][entity]['corrected_name'],
340+
"text": entity_text,
337341
"start": sentence['char_index']['start'],
338342
"end": sentence['char_index']['end'],
339343
"labels": [ent_type]

0 commit comments

Comments
 (0)