-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_splits.py
More file actions
316 lines (251 loc) · 11.2 KB
/
generate_splits.py
File metadata and controls
316 lines (251 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#!/usr/bin/env python3
"""
Generate train/val/test split indices for PRISM datasets.
This script loads an actual dataset and creates reproducible train/validation/test splits
that are saved as JSON files. These JSON files can then be used by DataPreparer to
ensure consistent splits across different training runs.
Usage:
python generate_splits.py --config configs/prism_default.yaml
python generate_splits.py --data-path data/tahoe_4M_moa-encoded.parquet --output data/splits/
"""
import argparse
import json
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from typing import Dict, List, Tuple
import sys
import os
# Add the parent directory to sys.path to import prism modules
sys.path.insert(0, str(Path(__file__).parent))
try:
from prism.configs import Config
except ImportError:
print("Warning: Could not import prism.configs. Using fallback configuration.")
Config = None
def load_dataset_info(data_path: str) -> Tuple[int, np.ndarray]:
"""
Load dataset to get sample count and labels for stratified splitting.
Args:
data_path: Path to the parquet file
Returns:
Tuple of (n_samples, labels_array)
"""
print(f"Loading dataset from: {data_path}")
# Load the parquet file
df = pd.read_parquet(data_path)
n_samples = len(df)
print(f"Dataset contains {n_samples:,} samples")
print(f"Dataset columns: {list(df.columns)}")
# Try to find the MoA labels column
moa_column = None
possible_moa_columns = ['moa-fine', 'moa-fine-encoded', 'moa', 'moa_fine', 'labels']
for col in possible_moa_columns:
if col in df.columns:
moa_column = col
break
if moa_column is None:
print("Warning: No MoA column found. Using random stratification.")
# Create dummy labels for stratification
labels = np.random.randint(0, 10, n_samples)
else:
print(f"Using column '{moa_column}' for stratified splitting")
labels = df[moa_column].values
# Handle string labels
if isinstance(labels[0], str):
unique_labels = np.unique(labels)
print(f"Found {len(unique_labels)} unique string labels, converting to integers")
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
labels = np.array([label_to_idx[label] for label in labels])
print(f"Label distribution: {np.bincount(labels)}")
return n_samples, labels
def create_stratified_splits(n_samples: int, labels: np.ndarray,
train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1,
random_state: int = 42) -> Dict[str, List[int]]:
"""
Create stratified train/validation/test splits.
Args:
n_samples: Total number of samples
labels: Array of labels for stratification
train_ratio: Proportion for training set
val_ratio: Proportion for validation set
test_ratio: Proportion for test set
random_state: Random seed for reproducibility
Returns:
Dictionary with 'train', 'val', 'test' keys containing lists of indices
"""
# Validate ratios
total_ratio = train_ratio + val_ratio + test_ratio
if abs(total_ratio - 1.0) > 1e-6:
raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")
print(f"Creating splits with ratios - Train: {train_ratio}, Val: {val_ratio}, Test: {test_ratio}")
# Create index array
indices = np.arange(n_samples)
# First split: separate test set
if test_ratio > 0:
train_val_indices, test_indices = train_test_split(
indices,
test_size=test_ratio,
random_state=random_state,
stratify=labels
)
else:
train_val_indices = indices
test_indices = np.array([])
# Second split: separate train and validation from remaining data
if val_ratio > 0 and len(train_val_indices) > 0:
# Calculate validation ratio relative to remaining data
remaining_ratio = train_ratio + val_ratio
val_ratio_adjusted = val_ratio / remaining_ratio
train_indices, val_indices = train_test_split(
train_val_indices,
test_size=val_ratio_adjusted,
random_state=random_state,
stratify=labels[train_val_indices]
)
else:
train_indices = train_val_indices
val_indices = np.array([])
# Create the splits dictionary
splits = {
'train': train_indices.tolist(),
'val': val_indices.tolist(),
'test': test_indices.tolist()
}
# Print split statistics
print(f"Split sizes:")
print(f" Train: {len(splits['train'])} samples ({len(splits['train'])/n_samples:.1%})")
print(f" Val: {len(splits['val'])} samples ({len(splits['val'])/n_samples:.1%})")
print(f" Test: {len(splits['test'])} samples ({len(splits['test'])/n_samples:.1%})")
# Verify no overlap and all samples covered
all_indices = set(splits['train'] + splits['val'] + splits['test'])
assert len(all_indices) == n_samples, "Not all samples are covered by splits"
assert len(all_indices) == len(splits['train']) + len(splits['val']) + len(splits['test']), "Overlapping indices detected"
print("✓ Split validation passed")
return splits
def save_splits_json(splits: Dict[str, List[int]], output_path: str,
metadata: Dict = None) -> None:
"""
Save splits to JSON file with optional metadata.
Args:
splits: Dictionary containing the splits
output_path: Path where to save the JSON file
metadata: Optional metadata to include in the file
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Prepare the data to save
data_to_save = {
'splits': splits,
'metadata': metadata or {}
}
# Add generation timestamp
from datetime import datetime
data_to_save['metadata']['generated_at'] = datetime.now().isoformat()
data_to_save['metadata']['total_samples'] = len(splits['train']) + len(splits['val']) + len(splits['test'])
# Save to JSON
with open(output_path, 'w') as f:
json.dump(data_to_save, f, indent=2)
print(f"Splits saved to: {output_path}")
def main():
parser = argparse.ArgumentParser(description="Generate train/val/test splits for PRISM datasets")
# Input options
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--config', type=str, help='Path to PRISM config file')
group.add_argument('--data-path', type=str, help='Direct path to dataset parquet file')
# Output options
parser.add_argument('--output', type=str, help='Output directory or file path')
parser.add_argument('--filename', type=str, default='train_val_test_splits.json',
help='Output filename (default: train_val_test_splits.json)')
# Split ratio options
parser.add_argument('--train-ratio', type=float, default=0.8, help='Training set ratio (default: 0.8)')
parser.add_argument('--val-ratio', type=float, default=0.1, help='Validation set ratio (default: 0.1)')
parser.add_argument('--test-ratio', type=float, default=0.1, help='Test set ratio (default: 0.1)')
# Other options
parser.add_argument('--random-seed', type=int, default=42, help='Random seed for reproducibility (default: 42)')
parser.add_argument('--force', action='store_true', help='Overwrite existing split files')
args = parser.parse_args()
# Determine data path
if args.config:
if Config is None:
print("Error: Cannot load config file without prism.configs module")
sys.exit(1)
config = Config(config_files=[args.config])
data_path = config.get('data.processed_parquet_path') or config.get('data.raw_parquet_path')
if not data_path:
print("Error: No data path found in config file")
sys.exit(1)
# Determine output path from config if not specified
if not args.output:
args.output = config.get('data.splits_path', 'data/splits/train_val_test_splits.json')
else:
data_path = args.data_path
# Determine output path if not specified
if not args.output:
data_dir = Path(data_path).parent
args.output = data_dir / 'splits' / args.filename
# Ensure data file exists
if not Path(data_path).exists():
print(f"Error: Data file not found: {data_path}")
sys.exit(1)
# Determine final output path
output_path = Path(args.output)
if output_path.is_dir():
output_path = output_path / args.filename
# Check if output file already exists
if output_path.exists() and not args.force:
print(f"Error: Output file already exists: {output_path}")
print("Use --force to overwrite")
sys.exit(1)
print(f"=== PRISM Dataset Split Generator ===")
print(f"Data source: {data_path}")
print(f"Output: {output_path}")
print(f"Split ratios: {args.train_ratio:.1%} train, {args.val_ratio:.1%} val, {args.test_ratio:.1%} test")
print(f"Random seed: {args.random_seed}")
print()
try:
# Load dataset information
n_samples, labels = load_dataset_info(data_path)
# Create splits
splits = create_stratified_splits(
n_samples=n_samples,
labels=labels,
train_ratio=args.train_ratio,
val_ratio=args.val_ratio,
test_ratio=args.test_ratio,
random_state=args.random_seed
)
# Prepare metadata
metadata = {
'source_file': str(data_path),
'random_seed': args.random_seed,
'train_ratio': args.train_ratio,
'val_ratio': args.val_ratio,
'test_ratio': args.test_ratio,
'n_samples': n_samples,
'n_unique_labels': len(np.unique(labels))
}
# Save splits (but extract just the splits for DataPreparer compatibility)
splits_only = {
'train': splits['train'],
'val': splits['val'],
'test': splits['test']
}
# Save the splits in DataPreparer-compatible format
with open(output_path, 'w') as f:
json.dump(splits_only, f, indent=2)
print(f"\n✓ Splits successfully generated and saved to: {output_path}")
print(f"\nTo use these splits with DataPreparer:")
print(f"1. Ensure your config file has:")
print(f" data.splits_path: \"{output_path}\"")
print(f"2. The DataPreparer will automatically load these splits when prepare_train_val_test() is called")
# Also save a detailed version with metadata
metadata_output = output_path.with_suffix('.metadata.json')
save_splits_json(splits, metadata_output, metadata)
print(f"3. Detailed metadata saved to: {metadata_output}")
except Exception as e:
print(f"Error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()