Skip to content

Commit 67088c4

Browse files
committed
tests: added tests for calculate/plot methods
1 parent da53411 commit 67088c4

1 file changed

Lines changed: 77 additions & 10 deletions

File tree

tests/entity_extraction/test_entity_extraction_evaluation.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,30 @@ def sample_labelled_entities():
3131
return test_labelled_entities
3232

3333

34+
@pytest.fixture
def example_correct_tokens():
    """Single-sentence fixture where predictions match the gold labels exactly.

    Returns:
        tuple: (true_tokens, predicted_tokens), each a list containing one
        identical sequence of IOB2 labels.
    """
    gold_labels = ["B-TAXA", "I-TAXA", "O", "B-AGE"]
    # prediction is a fresh copy of the gold sequence -> every metric is perfect
    return [list(gold_labels)], [list(gold_labels)]
44+
45+
46+
@pytest.fixture
def example_incorrect_tokens():
    """Single-sentence fixture where the prediction is partially wrong.

    The predicted sequence misses the first token of the TAXA entity and
    shifts the B- tag onto the second token.

    Returns:
        tuple: (true_tokens, predicted_tokens) lists of IOB2 label sequences.
    """
    gold_labels = [["B-TAXA", "I-TAXA", "O", "B-AGE"]]
    noisy_labels = [["O", "B-TAXA", "O", "B-AGE"]]
    return gold_labels, noisy_labels
56+
57+
3458
# first test that the correct tokens are labelled as entities and not "O"
3559
def test_get_token_labels(sample_text, sample_labelled_entities):
3660
expected_non_null_labels = [3, 4, 5, 6, 8, 9, 12]
@@ -41,18 +65,61 @@ def test_get_token_labels(sample_text, sample_labelled_entities):
4165
assert token_labels[i] != "O"
4266

4367

44-
# test the ideal case of passing in the same labelled tokens and predicted tokens
45-
def test_calculate_entity_classification_metrics(sample_text, sample_labelled_entities):
46-
split_text, sample_token_labels = get_token_labels(
47-
sample_labelled_entities, sample_text
68+
def test_calculate_entity_classification_metrics_with_correct_input(
    example_correct_tokens,
):
    """Perfect predictions must score 1.0 on every metric for both methods."""
    true_tokens, predicted_tokens = example_correct_tokens

    for method in ("tokens", "entity"):
        accuracy, f1, recall, precision = calculate_entity_classification_metrics(
            true_tokens, predicted_tokens, method=method
        )
        # round to 2 decimal places to sidestep floating-point noise
        for metric in (accuracy, f1, recall, precision):
            assert round(metric, 2) == 1.0
91+
92+
93+
def test_calculate_entity_classification_metrics_with_incorrect_input(
    example_incorrect_tokens,
):
    """Partially wrong predictions must yield the known metric values."""
    true_tokens, predicted_tokens = example_incorrect_tokens

    # expected (accuracy, f1, recall, precision) per scoring method,
    # rounded to 2 decimal places
    expected = {
        "tokens": (0.75, 0.8, 0.67, 1.0),
        "entity": (0.5, 0.5, 0.5, 0.5),
    }

    for method, (exp_acc, exp_f1, exp_rec, exp_prec) in expected.items():
        accuracy, f1, recall, precision = calculate_entity_classification_metrics(
            true_tokens, predicted_tokens, method=method
        )
        assert round(accuracy, 2) == exp_acc
        assert round(f1, 2) == exp_f1
        assert round(recall, 2) == exp_rec
        assert round(precision, 2) == exp_prec
115+
116+
117+
def test_plot_classification_report(example_correct_tokens):
    """The report plot should be created and carry the requested title."""
    true_tokens, predicted_tokens = example_correct_tokens
    expected_title = "Test Plot"

    figure = plot_token_classification_report(
        true_tokens,
        predicted_tokens,
        title=expected_title,
        method="tokens",
        display=False,
    )

    # a plot object must come back, and its first axes must carry the title
    assert figure is not None
    assert figure.axes[0].get_title() == expected_title

0 commit comments

Comments
 (0)