@@ -31,6 +31,30 @@ def sample_labelled_entities():
3131 return test_labelled_entities
3232
3333
34+ @pytest .fixture
35+ def example_correct_tokens ():
36+ true_tokens = [
37+ ["B-TAXA" , "I-TAXA" , "O" , "B-AGE" ],
38+ ]
39+ predicted_tokens = [
40+ ["B-TAXA" , "I-TAXA" , "O" , "B-AGE" ],
41+ ]
42+
43+ return true_tokens , predicted_tokens
44+
45+
46+ @pytest .fixture
47+ def example_incorrect_tokens ():
48+ true_tokens = [
49+ ["B-TAXA" , "I-TAXA" , "O" , "B-AGE" ],
50+ ]
51+ predicted_tokens = [
52+ ["O" , "B-TAXA" , "O" , "B-AGE" ],
53+ ]
54+
55+ return true_tokens , predicted_tokens
56+
57+
3458# first test that the correct tokens are labelled as entities and not "O"
3559def test_get_token_labels (sample_text , sample_labelled_entities ):
3660 expected_non_null_labels = [3 , 4 , 5 , 6 , 8 , 9 , 12 ]
@@ -41,18 +65,61 @@ def test_get_token_labels(sample_text, sample_labelled_entities):
4165 assert token_labels [i ] != "O"
4266
4367
44- # test the ideal case of passing in the same labelled tokens and predicted tokens
45- def test_calculate_entity_classification_metrics (sample_text , sample_labelled_entities ):
46- split_text , sample_token_labels = get_token_labels (
47- sample_labelled_entities , sample_text
68+ def test_calculate_entity_classification_metrics_with_correct_input (
69+ example_correct_tokens ,
70+ ):
71+ true_tokens , predicted_tokens = example_correct_tokens
72+
73+ accuracy , f1 , recall , precision = calculate_entity_classification_metrics (
74+ true_tokens , predicted_tokens , method = "tokens"
4875 )
4976
50- # test that the accuracy, f1, and recall scores are equal to 1
77+ # ensure the f1, accuracy, recall and precision are correct to 2 decimal places
78+ assert round (f1 , 2 ) == 1.0
79+ assert round (accuracy , 2 ) == 1.0
80+ assert round (recall , 2 ) == 1.0
81+ assert round (precision , 2 ) == 1.0
82+
5183 accuracy , f1 , recall , precision = calculate_entity_classification_metrics (
52- sample_token_labels , sample_token_labels , method = "tokens"
84+ true_tokens , predicted_tokens , method = "entity"
85+ )
86+ # ensure the f1, accuracy, recall and precision are correct to 2 decimal places
87+ assert round (f1 , 2 ) == 1.0
88+ assert round (accuracy , 2 ) == 1.0
89+ assert round (recall , 2 ) == 1.0
90+ assert round (precision , 2 ) == 1.0
91+
92+
93+ def test_calculate_entity_classification_metrics_with_incorrect_input (
94+ example_incorrect_tokens ,
95+ ):
96+ true_tokens , predicted_tokens = example_incorrect_tokens
97+
98+ accuracy , f1 , recall , precision = calculate_entity_classification_metrics (
99+ true_tokens , predicted_tokens , method = "tokens"
100+ )
101+ # ensure the f1, accuracy, recall and precision are correct to 2 decimpal places
102+ assert round (f1 , 2 ) == 0.8
103+ assert round (accuracy , 2 ) == 0.75
104+ assert round (recall , 2 ) == 0.67
105+ assert round (precision , 2 ) == 1.0
106+
107+ accuracy , f1 , recall , precision = calculate_entity_classification_metrics (
108+ true_tokens , predicted_tokens , method = "entity"
109+ )
110+ # ensure the f1, accuracy, recall and precision are correct to 2 decimpal places
111+ assert round (f1 , 2 ) == 0.5
112+ assert round (accuracy , 2 ) == 0.5
113+ assert round (recall , 2 ) == 0.5
114+ assert round (precision , 2 ) == 0.5
115+
116+
117+ def test_plot_classification_report (example_correct_tokens ):
118+ true_tokens , predicted_tokens = example_correct_tokens
119+
120+ plot = plot_token_classification_report (
121+ true_tokens , predicted_tokens , title = "Test Plot" , method = "tokens" , display = False
53122 )
54123
55- assert accuracy == 1
56- assert f1 == 1
57- assert recall == 1
58- assert precision == 1
124+ assert plot is not None
125+ assert plot .axes [0 ].get_title () == "Test Plot"
0 commit comments