forked from azk0019/CourseProject
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_model.py
More file actions
40 lines (30 loc) · 1.33 KB
/
eval_model.py
File metadata and controls
40 lines (30 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import torch
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from model import ClassificationModel
def eval(eval_dataset):
    """Evaluate the classifier on ``eval_dataset`` and report per-class metrics.

    For every record, the title and summary are joined into one string,
    tokenized, passed through BERT, and the pooled output is fed to the
    classifier. The arg-max of the logits is taken as the prediction. The
    resulting scikit-learn classification report is printed and also written
    to ``metrics.txt`` in the current working directory.

    This function reads the module-level names ``tokenizer``, ``bert`` and
    ``model``, which this file creates in its ``__main__`` block.

    NOTE: the name shadows the builtin ``eval``; it is kept unchanged so that
    existing callers keep working.

    Args:
        eval_dataset: Mapping whose values are iterables of records, each
            record being a mapping with string fields ``"title"`` and
            ``"summary"``. The position of each key in the mapping's
            iteration order is used as the gold label -- presumably this
            matches the class indices used at training time; TODO confirm.
    """
    # Only the position of each group is needed (it is the gold label), so
    # iterate the values rather than (key, value) pairs.
    iterator = tqdm(enumerate(eval_dataset.values()), total=len(eval_dataset))
    preds, labels = [], []

    # Sequences longer than the encoder's position-embedding table would make
    # the forward pass fail, so cap the tokenized length at that limit.
    max_length = bert.config.max_position_embeddings

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for label, records in iterator:
            for record in records:
                text = record["title"] + ". " + record["summary"]
                # NOTE(review): add_special_tokens=False means no [CLS]/[SEP]
                # are added; left as-is because it presumably mirrors the
                # training-time preprocessing -- confirm against training code.
                tokens = tokenizer(
                    text,
                    add_special_tokens=False,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                )
                embeddings = bert(
                    tokens["input_ids"],
                    attention_mask=tokens["attention_mask"],
                ).pooler_output
                logits = model(embeddings)
                preds.append(torch.argmax(logits).item())
                labels.append(label)

    # Build the report once and reuse it for both the console and the file.
    report = classification_report(labels, preds)
    print(report)
    with open("metrics.txt", "w", encoding="utf-8") as f:
        f.write(report)
if __name__ == "__main__":
    # Script entry point: load the test set, the pretrained encoder and the
    # trained classifier head, then run the evaluation.
    #
    # ``tokenizer``, ``bert`` and ``model`` are deliberately bound at module
    # level because ``eval`` reads them as globals.

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open("test.json", encoding="utf-8") as f:
        eval_dataset = json.load(f)

    pretrained_name = "allenai/scibert_scivocab_uncased"
    tokenizer = BertTokenizer.from_pretrained(pretrained_name)
    bert = BertModel.from_pretrained(pretrained_name)
    # Make inference mode explicit for the encoder (disables dropout).
    bert.eval()

    # Everything in this script runs on the CPU, so map the checkpoint there;
    # otherwise a checkpoint saved from a GPU fails to load on a CPU-only host.
    state_dict = torch.load("./model.pt", map_location="cpu")
    model = ClassificationModel()
    model.load_state_dict(state_dict)
    model.eval()

    eval(eval_dataset)