-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmissing_data_provider.py
More file actions
355 lines (289 loc) · 14.8 KB
/
missing_data_provider.py
File metadata and controls
355 lines (289 loc) · 14.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import math
from functools import reduce
from typing import Union
import numpy as np
import torch
from torch.distributions.beta import Beta
from torch.utils.data import Dataset
__all__ = ['MissingDataProvider',
'_generate_patterns',
'_random_weights_for_patterns']
class MissingDataProvider(Dataset):
    """
    Generates missing data in the given dataset.

    The missingness is denoted by adding an additional boolean tensor M,
    whose size is the same as X (1 = observed, 0 = missing).

    Args:
        dataset (torch.utils.data.Dataset): Fully-observed PyTorch dataset
        target_idx (int): If the dataset returns tuples, then this should be the index
            of the target data in the tuple for which the missing mask is added.
        total_miss (float): Total fraction of values to be made missing
        miss_type (str): The type of missingness to be generated
            - if `patterns`, then the provided patterns are used
            - if `MCAR`, then generates a uniformly distributed mask on the whole data
            - if `MAR`, `MNAR` (or `NMAR`), then generates #max_patterns random patterns
              and generates the missing masks using the patterns
        patterns (torch.Tensor or np.ndarray): Patterns to be used if miss_type == `patterns`
        rel_freqs (torch.Tensor or np.ndarray): Relative frequencies of the given patterns
        weights (torch.Tensor or np.ndarray): Mechanism to be used with the given patterns
        balances (torch.Tensor or np.ndarray): If the weights were previously fitted on some
            data then the fitted balances can be set too.
        max_patterns (int): Maximum number of patterns to generate for the chosen miss_type
            if the miss_type is `MAR`, `MNAR`, or `NMAR` (but not for MCAR).
        should_fit_to_data (bool): Whether the weights should be fitted to the data and new
            balance terms should be computed. Most often you want this set to True, unless
            you are reusing weights and balances that were fitted on other data (e.g.
            fitted on training data and reused on test data); in that case both weights
            and balances should be provided.
        rand_generator (torch.Generator): (optional) PyTorch random number generator.
    """
    def __init__(self,
                 dataset: Dataset,
                 target_idx: int = 0,
                 total_miss: float = 0.00,
                 miss_type: str = 'MCAR',
                 patterns: Union[np.ndarray, torch.Tensor] = None,
                 rel_freqs: Union[np.ndarray, torch.Tensor] = None,
                 weights: Union[np.ndarray, torch.Tensor] = None,
                 balances: Union[np.ndarray, torch.Tensor] = None,
                 max_patterns: int = None,
                 should_fit_to_data: bool = True,
                 rand_generator: torch.Generator = None):
        super().__init__()
        self.dataset = dataset
        self.target_idx = target_idx
        self.total_miss = total_miss
        self.miss_type = miss_type
        self.should_fit_to_data = should_fit_to_data
        self.max_patterns = max_patterns
        self.patterns = patterns
        self.rel_freqs = rel_freqs
        self.weights = weights
        self.balances = balances
        self._validate_args()
        # Initialise pseudo-random number generator
        self.rand_generator = (rand_generator
                               if rand_generator is not None
                               else torch.Generator())
        # Any preparations that need to be done before sampling masks
        self.prepare_prerequisites()
        # Sample the missingness mask
        self.miss_mask: torch.Tensor
        self.init_miss_mask()

    def prepare_prerequisites(self):
        """Prepare patterns/weights (if needed) and choose the incomplete samples.

        For MCAR nothing is needed. For MAR/MNAR/NMAR random patterns and
        weights are generated first; for `patterns` the user-provided ones are
        used. In both cases the set of samples to be made incomplete is chosen
        and, optionally, the weights are fitted to those samples.
        """
        if self.miss_type == 'MCAR':
            # No preparation needed for MCAR
            return
        # Get target data
        data = self._get_target_data()
        # If the type is one of the below then we generate random missingness patterns
        if self.miss_type in ('MNAR', 'NMAR', 'MAR'):
            # When generating patterns we generally want a little higher missing value
            # fraction in the patterns, so that we can have some completely-observed
            # cases too.
            # NOTE(review): Beta.sample() draws from torch's *global* RNG — it is
            # not controlled by self.rand_generator. Confirm whether this is
            # acceptable for reproducibility.
            pattern_miss = min(Beta(2, 5).sample().item(), 1)
            pattern_miss = (1 - self.total_miss) * pattern_miss + self.total_miss
            self.patterns, self.rel_freqs = _generate_patterns(
                self.max_patterns,
                D=data.shape[-1],
                total_miss=pattern_miss,
                rand_generator=self.rand_generator)
            self.weights = _random_weights_for_patterns(
                self.patterns,
                miss_mech=self.miss_type,
                dtype=data.dtype,
                rand_generator=self.rand_generator)
        # Convert total % of missing values to % of incomplete samples
        self.incomp_frac = self._incomplete_sample_fraction(
            self.patterns,
            self.rel_freqs,
            self.total_miss,
            data.shape)
        # Choose samples to be made incomplete
        self.incomp_idxs, _ = \
            self._split_comp_and_incomp_idxs(data, self.incomp_frac)
        if self.should_fit_to_data:
            self.fit_to_data(data[self.incomp_idxs, :])

    def init_miss_mask(self):
        """Sample and store the boolean missingness mask for the whole dataset."""
        # Get all target data
        data = self._get_target_data()
        if self.miss_type == 'MCAR':
            self.miss_mask = self._uniform_mask(data)
        elif self.miss_type in ('patterns', 'MNAR', 'NMAR', 'MAR'):
            self.miss_mask = self._sample_miss_mask(data)
        else:
            raise Exception('No such missingness mechanism type allowed:'
                            f' {self.miss_type}!')

    def __getitem__(self, idx):
        """Return the underlying sample with its missingness mask inserted.

        If the dataset yields tuples, the mask is inserted directly after the
        element at target_idx; otherwise a (data, mask) tuple is returned.
        """
        data = self.dataset[idx]
        miss_mask = self.miss_mask[idx]
        if isinstance(data, tuple):
            # Insert missingness mask after the target_idx tensor to which it corresponds
            data = (data[:self.target_idx+1]
                    + (miss_mask,)
                    + data[self.target_idx+1:])
        else:
            data = (data, miss_mask)
        return data

    def __len__(self):
        return len(self.dataset)

    def fit_to_data(self, incomp_data):
        """ Fits weights and balance terms

        Standardises the per-pattern scores on the incomplete samples and
        absorbs the standardisation into self.weights / self.balances, so that
        _sample_miss_mask_idx can use a plain softmax over raw scores.
        """
        # Compute the scores for each data point and its pattern's weight
        # other scores are set to zero
        wx = incomp_data @ self.weights.T
        # Where score is always 0, the pattern is MCAR
        mcar_dims = torch.all(wx == 0, dim=0)
        # Compute the standardised z-score for each pattern
        score_std, score_mean = torch.std_mean(wx, unbiased=True, dim=0)
        # Prevent division by zero for MCAR patterns
        score_std[mcar_dims] = 1.
        wx_standardised = (wx - score_mean) / score_std
        # Compute the balance term for each pattern
        b = torch.log(self.rel_freqs) - torch.mean(wx_standardised, dim=0)
        self.weights = self.weights / score_std[:, None]
        self.balances = b - (score_mean / score_std)

    def _sample_miss_mask(self, data):
        """Build the full mask: sampled patterns on incomplete rows, ones elsewhere."""
        incomp_data = data[self.incomp_idxs, :]
        # Sample missingness masks
        incomp_mask_idxs = self._sample_miss_mask_idx(incomp_data,
                                                      self.weights,
                                                      self.balances)
        incomp_masks = self.patterns[incomp_mask_idxs]
        # Create a mask of all 1s for the fully-observed data-points
        all_masks = torch.ones_like(data, dtype=torch.bool)
        all_masks[self.incomp_idxs] = incomp_masks
        return all_masks

    def _get_target_data(self, idx=slice(None)):
        """Return the (tensor) data the masks are generated for.

        NOTE: this won't work with large datasets that do not fit into memory.
        """
        data = self.dataset[idx]
        if isinstance(data, tuple):
            # Get the data for which the missing masks are generated
            data = data[self.target_idx]
        if isinstance(data, np.ndarray):
            data = torch.tensor(data)
        return data

    def _validate_args(self):
        """Validate the constructor arguments; raises AssertionError on misuse."""
        assert self.miss_type in ('patterns', 'MCAR', 'MAR', 'MNAR', 'NMAR'),\
            f'Invalid missingness mechanism type: {self.miss_type}!'
        assert 0 <= self.total_miss <= 1,\
            f'Invalid total missingness: {self.total_miss:.2f}!'
        if self.miss_type == 'patterns':
            # BUGFIX: the original `None not in (tensor, ...)` relied on tuple
            # membership, which calls Tensor.__eq__ on each element and can
            # produce elementwise comparisons / ambiguous truth values.
            # Explicit identity checks are safe for tensors and arrays.
            assert all(v is not None
                       for v in (self.patterns, self.rel_freqs, self.weights)),\
                'For miss_type==patterns, patterns, rel_freqs, and weights must be provided!'
        if self.patterns is not None:
            assert (self.patterns.shape[0] == self.rel_freqs.shape[0]
                    and self.patterns.shape[0] == self.weights.shape[0]
                    and (self.balances is None
                         or self.patterns.shape[0] == self.balances.shape[0])),\
                'Shapes of patterns, rel_freqs, or weights (and balances) do not match!'
        if self.miss_type == 'patterns' and not self.should_fit_to_data:
            assert self.weights is not None and self.balances is not None,\
                ('For miss_type==`patterns` and should_fit_to_data==False, the fitted weights '
                 'and balances should be provided!')
        if self.miss_type in ('MCAR', 'MAR', 'MNAR', 'NMAR'):
            assert self.should_fit_to_data, \
                'If generating missingness patterns, then should_fit_to_data should be set to True!'
        if self.miss_type in ('MAR', 'MNAR', 'NMAR'):
            assert self.max_patterns is not None, \
                'If generating missingness patterns, then max_patterns should be set given!'

    def _uniform_mask(self, data):
        """Uniformly-random (MCAR) mask with exactly total_miss fraction of zeros."""
        # Works for PyTorch and Numpy
        total_values = reduce(lambda x, y: x*y, data.shape, 1)
        # Generate appropriate number of missing values
        miss_mask = torch.ones(total_values, dtype=torch.bool)
        miss_mask[:int(self.total_miss*total_values)] = 0
        # Randomise mask
        rand_idx = torch.randperm(total_values, generator=self.rand_generator)
        miss_mask = miss_mask[rand_idx]
        return miss_mask.reshape_as(data)

    def _incomplete_sample_fraction(self, patterns, rel_freqs,
                                    total_miss, data_shape):
        """Convert a total fraction of missing *values* into a fraction of
        incomplete *samples*, given the patterns' per-row missing counts."""
        total_values = reduce(lambda x, y: x*y, data_shape, 1)
        miss_values = total_miss * total_values
        # The number of incomplete samples for each pattern
        C = miss_values * rel_freqs / torch.sum(~patterns, dim=1)
        # Total number of incomplete samples for all patterns
        C = torch.sum(C)
        # Clamp: the patterns/frequencies may not be compatible with the
        # requested total missingness, in which case every sample is incomplete.
        C = min(C.item(), data_shape[0])
        # The incomplete fraction of all data points
        return C / data_shape[0]

    def _split_comp_and_incomp_idxs(self, data, frac_incomp_samples):
        """Randomly split sample indices into (incomplete, complete) subsets."""
        # The total number of incomplete data points
        total_incomp = int(math.floor(frac_incomp_samples * data.shape[0]))
        # Randomly split the data into incomplete and complete subsets
        all_idx = torch.randperm(data.shape[0], generator=self.rand_generator)
        return all_idx[:total_incomp], all_idx[total_incomp:]

    def _sample_miss_mask_idx(self, incomp_data, weights, balances):
        """Sample one pattern index per incomplete sample via a softmax model."""
        scores = incomp_data @ weights.T + balances
        probs = torch.nn.functional.softmax(scores, dim=-1)
        return torch.multinomial(probs, 1,
                                 replacement=False,
                                 generator=self.rand_generator).squeeze()
# TODO: note that this can run into an infinite loop with poor choice of arguments
# TODO: add a check
def _generate_patterns(max_patterns, D, total_miss, rand_generator=None):
"""
Generate missingness patterns as binary masks. 1 is observed and 0 is missing.
Args:
max_patterns (int): maximum number of binary missingness patterns
D (int): dimensionality of each patterns
total_miss (float): The total fraction of missing values in the
patterns, between 0 and 1.
rand_generator (torch.Generator): (optional) PyTorch random number generator.
"""
rand_generator = (rand_generator if rand_generator is not None
else torch.Generator())
rand_generator.seed()
# Create an array with the desired fraction of missing values
total_vals = max_patterns*D
miss_vals = int(total_vals*total_miss)
def gen_patterns():
patterns = torch.ones((total_vals, ), dtype=torch.bool)
patterns[:miss_vals] = 0.
# Shuffle the array and reshape to the desired pattern shape
rand_idx = torch.randperm(total_vals, generator=rand_generator)
patterns = patterns[rand_idx]
patterns = patterns.reshape(max_patterns, D)
return patterns
patterns = gen_patterns()
# NOTE: this may never finish if there is no solution so
# an assertion is needed
# We want to prevent fully-patterns (all ones), so generate until
# we get one.
while torch.any(torch.all(patterns, dim=-1)).item():
patterns = gen_patterns()
# Only return unique patterns and get their relative frequency
# TODO: maybe not unique?
patterns, rel_freqs = torch.unique(patterns, dim=0, return_counts=True)
rel_freqs = rel_freqs.float()
rel_freqs = rel_freqs / rel_freqs.sum()
return patterns, rel_freqs
def _random_weights_for_patterns(patterns, miss_mech,
dtype=torch.float, rand_generator=None):
"""
Generate a random weight matrix for each pattern. Weight values in [-1, 1)
Args:
patterns (torch.Tensor): Patterns for which to generate random weights
miss_mech (str): MNAR (zeros where patterns==1) or MAR (zeros where patterns==0)
dtype (torch.dtype): Data type of the weights
rand_generator (torch.Generator): (optional) PyTorch random number generator.
"""
rand_generator = (rand_generator if rand_generator is not None
else torch.Generator())
# Generate random weights in [-1, 1)
weights = torch.rand(*(patterns.shape), dtype=dtype)*2-1
if miss_mech in ('MNAR', 'NMAR'):
# Make sure the patterns depend on missing variables only
weights *= ~patterns
elif miss_mech == 'MAR':
# Make sure the patterns depend on observed variables only
weights *= patterns
else:
raise Exception('No such missingness mechanism type allowed:'
f' {miss_mech}!')
return weights