-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
68 lines (60 loc) · 3.45 KB
/
run.py
File metadata and controls
68 lines (60 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os, argparse
from src.Main import run_evaluation
from src.DatasetCreation import create_dataset_versions
from src.PreparingMethod import PreparingMethod
if __name__ == '__main__':
    # CLI entry point: parse arguments, prepare dataset versions, then run the
    # evaluation pipeline (or only cache datasets when --cache_only is given).
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', required=True, help='Directory to save results')
    parser.add_argument('--data_dir', required=True, help='Base directory containing datasets')
    parser.add_argument('--dataset', required=True, help='Name of the dataset to load')
    parser.add_argument('--train_method', required=True, help='Method to apply to training data')
    parser.add_argument('--test_method', required=True, help='Method to apply to testing data')
    parser.add_argument('--group_duplicates', action='store_true', help='Whether to deduplicate records')
    parser.add_argument('--use_gpu', action='store_true', help='Whether to use GPU acceleration')
    # n_workers is a count, so parse it as an int instead of passing a raw string downstream.
    parser.add_argument('--n_workers', required=True, type=int, help='Size of dask cluster to use')
    # NOTE: store_false means filtering by record id defaults to True and the
    # flag DISABLES it (column-based filtering is deprecated per the help text).
    parser.add_argument('--filter_by_record_id', action='store_false', help='Whether to filter by record id - however filtering by column is deprecated')
    parser.add_argument('--percentages', required=True, help='Percentages for the dataset split')
    parser.add_argument('--cache_only', action='store_true', help='Whether to only prepare and cache datasets without running evaluations')
    args = parser.parse_args()
    print(f"Saving to: {args.save_dir}")
    print(f"Dataset: {args.dataset}")
    print(f"Train method: {args.train_method}")
    print(f"Test method: {args.test_method}")
    print(f"Group duplicates: {args.group_duplicates}")
    print(f"Filter by record id: {args.filter_by_record_id}")
    print(f"Using GPU: {args.use_gpu}")
    print(f"Using n_workers: {args.n_workers}")
    print(f"Percentages: {args.percentages}")
    print(f"Cache only: {args.cache_only}")
    # Convert string arguments to appropriate PreparingMethod enum values.
    # getattr raises AttributeError for an unknown method name, which surfaces
    # the bad CLI value immediately.
    train_method = getattr(PreparingMethod, args.train_method)
    test_method = getattr(PreparingMethod, args.test_method)
    # Parse the percentages string (comma- or space-separated) into three floats.
    try:
        pct_values = [float(x) for x in args.percentages.replace(',', ' ').split()]
        if len(pct_values) != 3:
            raise ValueError("Percentages must have exactly three values.")
        original_pct, generalized_pct, missing_pct = pct_values
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error parsing percentages: {args.percentages} ({e})") from e
    # Create dataset versions based on percentages; pct_str is the canonical
    # "XX-YY-ZZ" label (integer percent points) used downstream.
    pct_str = f"{int(round(original_pct*100))}-{int(round(generalized_pct*100))}-{int(round(missing_pct*100))}"
    print(f"Preparing data for {args.dataset} with {pct_str} split...")
    seed = 42  # fixed seed so dataset splits are reproducible across runs
    create_dataset_versions(args.dataset, original_pct, generalized_pct, missing_pct, seed, args.data_dir, filter_by_record_id=args.filter_by_record_id)
    # Add dataset subdirectory to save_dir (matches behavior in run.job)
    save_dir_with_dataset = os.path.join(args.save_dir, args.dataset)
    # Run the evaluation with the provided parameters
    run_evaluation(
        n_workers=args.n_workers,
        use_gpu=args.use_gpu,
        save_dir=save_dir_with_dataset,
        data_dir=args.data_dir,
        dataset=args.dataset,
        train_method=train_method,
        test_method=test_method,
        group_duplicates=args.group_duplicates,
        filter_by_record_id=args.filter_by_record_id,
        percentages=pct_str,
        cache_only=args.cache_only,
    )