Skip to content

Commit 8ea406b

Browse files
committed
tests: fixes to complete test suite
1 parent bb14221 commit 8ea406b

File tree

2 files changed

+180
-3
lines changed

2 files changed

+180
-3
lines changed

tests/entity_extraction/test_entity_extraction_evaluation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def sample_labelled_entities():
2626
test_labelled_entities = [
2727
{"start": 13, "end": 33, "labels": ["ALTI"], "text": "120m above sea level"},
2828
{"start": 38, "end": 45, "labels": ["AGE"], "text": "1234 BP"},
29-
{"start": 56, "end": 66, "labels": ["TAXA"], "text": "Pediastrum"},
29+
{"start": 56, "end": 65, "labels": ["TAXA"], "text": "Pediastrum"},
3030
]
3131
return test_labelled_entities
3232

@@ -35,15 +35,17 @@ def sample_labelled_entities():
3535
def test_get_token_labels(sample_text, sample_labelled_entities):
    """Tokens covered by a labelled entity must carry a non-"O" label."""
    # indices into the whitespace-split text that fall inside an entity span
    expected_non_null_labels = [3, 4, 5, 6, 8, 9, 12]

    split_text, token_labels = get_token_labels(sample_labelled_entities, sample_text)

    assert all(token_labels[i] != "O" for i in expected_non_null_labels)
4242

4343

4444
# test the ideal case of passing in the same labelled tokens and predicted tokens
4545
def test_calculate_entity_classification_metrics(sample_text, sample_labelled_entities):
46-
sample_token_labels = get_token_labels(sample_labelled_entities, sample_text)
46+
split_text, sample_token_labels = get_token_labels(
47+
sample_labelled_entities, sample_text
48+
)
4749

4850
# test that the accuracy, f1, and recall scores are equal to 1
4951
accuracy, f1, recall, precision = calculate_entity_classification_metrics(
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import os
2+
import sys
3+
import pytest
4+
import logging
5+
from transformers import pipeline
6+
import pandas as pd
7+
8+
logger = logging.getLogger(__name__)
9+
10+
# ensure that the parent directory is on the path for relative imports
11+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
12+
13+
from src.entity_extraction.training.hf_token_classification.hf_evaluate import (
14+
get_hf_token_labels,
15+
get_predicted_labels,
16+
generate_classification_results,
17+
)
18+
19+
20+
@pytest.fixture
def example_inputs():
    """Provide a short sentence plus HF-style entity spans covering it."""
    entity_spans = [
        dict(start=0, end=8, entity_group="PERSON"),
        dict(start=18, end=24, entity_group="LOCATION"),
    ]
    text = "John Doe lives in London."
    return entity_spans, text
28+
29+
30+
@pytest.fixture
def example_data():
    """Provide a DataFrame with a 'tokens' column of pre-split sentences."""
    sentences = [
        ["This", "is", "an", "example", "sentence."],
        ["Another", "example."],
    ]
    return pd.DataFrame({"tokens": sentences})
41+
42+
43+
@pytest.fixture
def ner_pipe():
    """Yield a HuggingFace token-classification pipeline for NER tests."""
    # NOTE(review): same checkpoint serves as both model and tokenizer;
    # first use presumably downloads it from the hub — slow and networked.
    model_id = "dslim/bert-base-NER"
    yield pipeline("ner", model=model_id, tokenizer=model_id)
50+
51+
52+
@pytest.fixture
def example_correct_tokens():
    """Gold and predicted IOB2 sequences that agree on every token."""
    gold = [["B-TAXA", "I-TAXA", "O", "B-AGE"]]
    preds = [["B-TAXA", "I-TAXA", "O", "B-AGE"]]
    return gold, preds
62+
63+
64+
@pytest.fixture
def example_incorrect_tokens():
    """Gold and predicted IOB2 sequences that disagree on the first token."""
    gold = [["B-TAXA", "I-TAXA", "O", "B-AGE"]]
    # prediction misses the opening B-TAXA and shifts the entity boundary
    preds = [["O", "B-TAXA", "O", "B-AGE"]]
    return gold, preds
74+
75+
76+
def test_get_hf_token_labels(example_inputs):
    """Entity spans are converted to whitespace tokens plus IOB2 labels."""
    entities, text = example_inputs

    split_text, token_labels = get_hf_token_labels(entities, text)

    assert split_text == ["John", "Doe", "lives", "in", "London."]
    assert token_labels == ["B-PERSON", "I-PERSON", "O", "O", "B-LOCATION"]
86+
87+
88+
def test_get_hf_token_labels_with_invalid_labelled_entities():
    """A non-list labelled_entities argument must raise TypeError."""
    with pytest.raises(TypeError):
        get_hf_token_labels("invalid", "Some text")
94+
95+
96+
def test_get_hf_token_labels_with_invalid_raw_text():
    """A non-string raw_text argument must raise TypeError."""
    with pytest.raises(TypeError):
        get_hf_token_labels([], 123)
102+
103+
104+
def test_get_predicted_labels(example_data, ner_pipe):
    """Prediction adds the expected columns without changing the row count."""
    result = get_predicted_labels(ner_pipe, example_data.copy())

    for column in ("joined_text", "predicted_labels", "split_text", "predicted_tokens"):
        assert column in result.columns
    assert len(result) == len(example_data)
114+
115+
116+
def test_get_predicted_labels_with_empty_dataframe(ner_pipe):
    """An empty DataFrame is rejected with ValueError."""
    with pytest.raises(ValueError):
        get_predicted_labels(ner_pipe, pd.DataFrame())
121+
122+
123+
def test_get_predicted_labels_with_missing_tokens_column(ner_pipe):
    """A DataFrame without a 'tokens' column raises KeyError."""
    frame = pd.DataFrame({"text": ["This is an example sentence.", "Another example."]})
    with pytest.raises(KeyError):
        get_predicted_labels(ner_pipe, frame)
128+
129+
130+
def test_generate_classification_results_with_correct_input(example_correct_tokens):
    """Identical gold and predicted labels score 1.0 on every metric."""
    gold, preds = example_correct_tokens

    results = generate_classification_results(gold, preds)

    # token- and entity-level scores should all be perfect (to 2 d.p.)
    for level in ("token", "entity"):
        for metric in ("f1", "accuracy", "recall", "precision"):
            assert round(results[level][metric], 2) == 1.0
144+
145+
146+
def test_generate_classification_results_with_incorrect_input(example_incorrect_tokens):
    """Partially wrong predictions yield the expected degraded metrics."""
    gold, preds = example_incorrect_tokens

    results = generate_classification_results(gold, preds)

    # token level: fixture disagrees on one of four tokens (to 2 decimal places)
    expected_token = {"f1": 0.80, "accuracy": 0.75, "recall": 0.67, "precision": 1.0}
    for metric, value in expected_token.items():
        assert round(results["token"][metric], 2) == value

    # entity level: only one of the two gold entities is recovered intact
    for metric in ("f1", "accuracy", "recall", "precision"):
        assert round(results["entity"][metric], 2) == 0.5
160+
161+
162+
def test_generate_classification_results_with_empty_input():
    """Empty label sequences are rejected with ValueError."""
    with pytest.raises(ValueError):
        generate_classification_results([], [])
168+
169+
170+
def test_generate_classification_results_with_invalid_input_lengths():
    """Mismatched gold/predicted sequence lengths raise ValueError."""
    gold = [["B-TAXA", "I-TAXA", "O", "B-PER"]]
    preds = [["B-TAXA", "I-TAXA", "O"]]  # one token short of the gold sequence
    with pytest.raises(ValueError):
        generate_classification_results(gold, preds)

0 commit comments

Comments
 (0)