-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathget_data.py
More file actions
32 lines (25 loc) · 1.15 KB
/
get_data.py
File metadata and controls
32 lines (25 loc) · 1.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pandas as pd
from config import VOCABULARY, SENTENCE_LEN
from utils import generate_sentence_pairs
# Number of examples to generate for each held-out split.
VAL_SIZE = 5000
TEST_SIZE = 1000


def _dedup_key(src):
    """Return a hashable stand-in for *src* for use in the seen-set.

    NOTE(review): the return type of ``generate_sentence_pairs`` is not
    visible from this file. Hashable values (e.g. strings, tuples) are used
    as-is; a list is converted to a tuple so it can be stored in a set.
    Confirm against ``utils.generate_sentence_pairs``.
    """
    return tuple(src) if isinstance(src, list) else src


def _generate_unique_split(size, seen):
    """Collect ``size`` sentence pairs whose source is not already in ``seen``.

    Args:
        size: Number of examples to collect.
        seen: Set of source keys that have already been used. It is updated
            in place, so a later split built with the same set cannot repeat
            a source from an earlier one.

    Returns:
        A dict with the columns ``"src"``, ``"tgt_shifted"`` and ``"tgt"``,
        each a list of ``size`` entries in generation order.
    """
    split = {"src": [], "tgt_shifted": [], "tgt": []}
    # NOTE(review): this loop only ends once enough distinct sources have
    # been drawn; it assumes the generator can produce at least
    # VAL_SIZE + TEST_SIZE different sources -- confirm against config.
    while len(split["src"]) < size:
        src, tgt_shifted, tgt = generate_sentence_pairs(VOCABULARY, SENTENCE_LEN)
        key = _dedup_key(src)
        if key in seen:
            # Duplicate source: discard this draw and try again.
            continue
        seen.add(key)
        split["src"].append(src)
        split["tgt_shifted"].append(tgt_shifted)
        split["tgt"].append(tgt)
    return split


# Sources accepted so far. A set gives constant-time membership checks, where
# scanning the growing "src" lists made every draw linear in the split size.
seen_sources = set()

# Generate the validation set, check that there are no duplicates
val_set = _generate_unique_split(VAL_SIZE, seen_sources)

# Generate the test set, check that there are no duplicates and overlaps with
# the validation set (seen_sources already holds every validation source)
test_set = _generate_unique_split(TEST_SIZE, seen_sources)

# Save the validation and test sets
pd.DataFrame(val_set).to_csv("val.csv", index=False)
pd.DataFrame(test_set).to_csv("test.csv", index=False)
print("Data generated successfully!")