|
| 1 | +# Author: Ty Andrews |
| 2 | +# Date: 2023-05-30 |
| 3 | +"""This script manages custom evaluation of the fine tuned hugging face models. |
| 4 | +
|
| 5 | +Usage: hf_evaluate.py --data_path=<data_path> --model_path=<model_path> --output_path=<output_path> --model_name=<model_name> [--max_samples=<max_samples>] |
| 6 | +
|
| 7 | +Options: |
| 8 | + --data_path=<data_path> The path to the evaluation data in json format. |
| 9 | + --model_path=<model_path> The path to the model to load. |
| 10 | + --output_path=<output_path> The path to export the results & plots to. |
| 11 | + --model_name=<model_name> The name of the model. |
| 12 | + --max_samples=<max_samples> The maximum number of samples to evaluate, set to 1 for CPU testing. [default: None] |
| 13 | +""" |
| 14 | + |
| 15 | +import os, sys |
| 16 | + |
| 17 | +import pandas as pd |
| 18 | +import numpy as np |
| 19 | + |
| 20 | +import time |
| 21 | +import json |
| 22 | +from tqdm import tqdm |
| 23 | +from docopt import docopt |
| 24 | +from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline |
| 25 | +import torch |
| 26 | + |
| 27 | +sys.path.append( |
| 28 | + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, os.pardir) |
| 29 | +) |
| 30 | + |
| 31 | +from src.entity_extraction.evaluation.entity_extraction_evaluation import ( |
| 32 | + calculate_entity_classification_metrics, |
| 33 | + plot_token_classification_report, |
| 34 | + generate_classification_results, |
| 35 | + generate_confusion_matrix, |
| 36 | + export_classification_results, |
| 37 | + export_classification_report_plots, |
| 38 | +) |
| 39 | +from src.logs import get_logger |
| 40 | + |
| 41 | +logger = get_logger(__name__) |
| 42 | + |
| 43 | + |
def get_hf_token_labels(labelled_entities, raw_text):
    """
    Returns a list of labels per token in the raw text from hugging face generated labels.

    Parameters
    ----------
    labelled_entities : list
        A list of dictionaries containing the labelled entities including the
        start and end character indices of the entity ("start"/"end") and the
        entity label ("entity_group").
    raw_text : str
        The raw text that the entities were extracted from.

    Returns
    -------
    split_text : list
        The raw text split on whitespace.
    token_labels : list
        A BIO label per whitespace token in the raw text ("O" when no entity
        covers the token).

    Raises
    ------
    TypeError
        If raw_text is not a string or labelled_entities is not a list.
    """
    # ensure raw_text is a string and labelled_entities is a list
    if not isinstance(raw_text, str):
        raise TypeError(f"raw_text must be a string. Got {type(raw_text)}")
    elif not isinstance(labelled_entities, list):
        raise TypeError(
            f"labelled_entities must be a list. Got {type(labelled_entities)}"
        )

    # split the text by whitespace
    split_text = raw_text.split()

    # default every token to the "outside" label
    token_labels = ["O"] * len(split_text)

    for entity in labelled_entities:
        start = entity["start"]
        end = entity["end"]
        label = entity["entity_group"]

        # map character offsets onto whitespace-token indices: the number of
        # whitespace-separated words strictly before the offset is the index.
        # NOTE(review): an offset falling mid-token shifts the index one token
        # right and can run past the end of token_labels — guarded below.
        token_start = len(raw_text[:start].split())
        token_end = len(raw_text[:end].split())

        try:
            # if the entity spans multiple tokens
            if token_start != token_end:
                token_labels[token_start] = f"B-{label}"
                for i in range(token_start + 1, token_end):
                    token_labels[i] = f"I-{label}"
            else:
                token_labels[token_start] = f"B-{label}"
        except IndexError as exc:
            # log and skip malformed offsets instead of printing to stdout or
            # aborting the whole evaluation run
            logger.error(
                "Error with entity: %s | raw text: %s | token start: %s | "
                "token end: %s (%s)",
                entity,
                raw_text,
                token_start,
                token_end,
                exc,
            )

    return split_text, token_labels
| 101 | + |
| 102 | + |
def load_ner_model_pipeline(model_path: str):
    """
    Loads a hugging face named entity recognition model.

    Parameters
    ----------
    model_path : str
        The path to the model to load.

    Returns
    -------
    ner_pipe : transformers.pipelines.Pipeline
        The ner model pipeline.
    model : transformers.modeling_outputs.TokenClassifierOutput
        The loaded model.
    tokenizer : transformers.tokenization_bert.BertTokenizer
        The loaded tokenizer.
    """
    # prefer the GPU with a larger batch when available, else single-item CPU
    device_str = "cuda:0" if torch.cuda.is_available() else "cpu"

    if device_str == "cpu":
        logger.info("Using CPU for predictions, batch size of 1")
        batch_size = 1
    else:
        logger.info("Using GPU for predictions, batch size of 32")
        batch_size = 32

    # load the model weights and a matching tokenizer capped at 512 tokens
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, model_max_length=512, padding=True, truncation=True
    )

    # wrap both in a token-classification pipeline that merges word pieces
    ner_pipe = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        device=torch.device(device_str),
        batch_size=batch_size,
        aggregation_strategy="simple",
    )

    return ner_pipe, model, tokenizer
| 145 | + |
| 146 | + |
def load_evaluation_data(data_file_path: str):
    """
    Loads the evaluation data from a json-lines file.

    Parameters
    ----------
    data_file_path : str
        The path to the evaluation data in json format (one record per line).

    Returns
    -------
    df : pandas.DataFrame
        The evaluation data.

    Raises
    ------
    FileNotFoundError
        If the data file does not exist.
    ValueError
        If the data file is not a json file.
    """
    # raise specific exception types (both Exception subclasses, so existing
    # broad handlers still work) instead of bare Exception
    if not os.path.exists(data_file_path):
        raise FileNotFoundError("Data file does not exist.")

    if not data_file_path.endswith(".json"):
        raise ValueError("Data file must be json format.")

    # each line of the file is a standalone json record
    df = pd.read_json(data_file_path, lines=True)

    return df
| 173 | + |
| 174 | + |
def get_predicted_labels(ner_pipe, df):
    """
    Gets the predicted labels from the hugging face model.

    Parameters
    ----------
    ner_pipe : transformers.pipelines.Pipeline
        The ner model pipeline.
    df : pandas.DataFrame
        The evaluation data; must contain a "tokens" column holding lists of
        string tokens.

    Returns
    -------
    df : pandas.DataFrame
        The evaluation data with the prediction columns added ("joined_text",
        "predicted_labels", "split_text", "predicted_tokens").

    Raises
    ------
    ValueError
        If the provided dataframe is empty.
    """
    if len(df) == 0:
        raise ValueError("The provided dataframe is empty.")

    # huggingface needs list of lists with strings for batch processing
    df["joined_text"] = df["tokens"].apply(lambda x: " ".join(x))

    # time the execution
    start = time.time()
    predicted_labels = ner_pipe(df.joined_text.to_list())
    # assign the plain list positionally: wrapping it in pd.Series(...) would
    # align on the default RangeIndex and silently misalign (or NaN-fill)
    # predictions whenever df carries a non-default index
    df["predicted_labels"] = predicted_labels
    logger.info(
        f"Prediction time for {len(df)} chunks: {time.time() - start:.2f} seconds"
    )

    # expand the (split_text, token_labels) tuples into two columns
    df[["split_text", "predicted_tokens"]] = df.apply(
        lambda row: get_hf_token_labels(row.predicted_labels, row.joined_text),
        axis="columns",
        result_type="expand",
    )

    return df
| 213 | + |
| 214 | + |
def main():
    """
    Evaluate a fine tuned hugging face NER model on matching data files.

    Parses command line options with docopt (see the module docstring), then
    for each train/val/test json file in --data_path: loads the data,
    optionally subsamples it, predicts entity labels, and exports
    classification results, report plots, and a confusion matrix to
    --output_path.
    """
    opt = docopt(__doc__)

    # load the model lazily: nothing is loaded when no files match, but the
    # model is loaded only once instead of being reloaded for every file
    ner_pipe = None

    # run evaluation for each json file in the data directory
    for file in os.listdir(opt["--data_path"]):
        # skip non json files and only ones that contain the words train/val/test
        if not file.endswith(".json") or (
            "train" not in file and "val" not in file and "test" not in file
        ):
            continue
        logger.info(f"Evaluating {file}")
        file_name = file.split(".")[0]

        # load the evaluation data
        df = load_evaluation_data(os.path.join(opt["--data_path"], file))

        if opt["--max_samples"] != "None":
            logger.info(
                f"Using just a subsample of the data of size {opt['--max_samples']}"
            )
            # reset index and drop it
            df = df.sample(int(opt["--max_samples"])).reset_index(drop=True)

        # load the model on first use only (loop-invariant --model_path)
        if ner_pipe is None:
            ner_pipe, model, tokenizer = load_ner_model_pipeline(opt["--model_path"])

        logger.info("Loaded model, generating predictions, this may take a while.")
        # get the predicted labels
        df = get_predicted_labels(ner_pipe, df)

        logger.info("Generated predictions, calculating classification results")
        # get the classification results
        classification_results = generate_classification_results(
            df.ner_tags.tolist(), df.predicted_tokens.tolist()
        )

        # export the classification results
        export_classification_results(
            classification_results,
            opt["--output_path"],
            opt["--model_name"] + "_" + file_name,
        )

        # export the classification report plots
        export_classification_report_plots(
            true_tokens=df.ner_tags.tolist(),
            predicted_tokens=df.predicted_tokens.tolist(),
            output_path=opt["--output_path"],
            model_name=opt["--model_name"] + "_" + file_name,
        )

        generate_confusion_matrix(
            labelled_tokens=df.ner_tags.tolist(),
            predicted_tokens=df.predicted_tokens.tolist(),
            output_path=opt["--output_path"],
            model_name=opt["--model_name"] + "_" + file_name,
        )
| 272 | + |
| 273 | + |
# run the evaluation when executed as a script (no-op on import)
if __name__ == "__main__":
    main()
0 commit comments