Skip to content

Commit 0061177

Browse files
PierreGtchfacebook-github-bot
authored and committed
Shuffled flatmap (meta-pytorch#1062)
Summary: Closes meta-pytorch#1043 ### Changes - Adds new datapipe `ShuffledFlatMapper`/`shuffled_flatmap` Pull Request resolved: meta-pytorch#1062 Reviewed By: NivekT Differential Revision: D44176076 Pulled By: ejguan fbshipit-source-id: a75f88fa3d28591d59f758d5affd6b1909841564
1 parent 8a9b963 commit 0061177

5 files changed

Lines changed: 298 additions & 0 deletions

File tree

docs/source/torchdata.datapipes.iter.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ These DataPipes apply a given function to each element in the DataPipe.
167167
BatchMapper
168168
FlatMapper
169169
Mapper
170+
ShuffledFlatMapper
170171
ThreadPoolMapper
171172

172173
Other DataPipes

test/test_iterdatapipe.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,142 @@ def mul_fn(a, b):
894894
with self.assertRaisesRegex(TypeError, "length relies on the output of its function."):
895895
len(flatmapped_dp)
896896

897+
def test_shuffled_flatmap_iterdatapipe(self):
    """Exercise ``shuffled_flatmap`` in its three modes.

    1. ``buffer_size=1``: only one iterable is buffered at a time, so the output
       order must match a plain ``flatmap``.
    2. ``.set_shuffle(False)``: shuffling is disabled, so the output order must
       again match a plain ``flatmap`` regardless of the buffer size.
    3. Shuffling enabled: checks buffer-size validation, that no element is lost,
       seed reproducibility and resetting mid-iteration.
    """
    source_dp = IterableWrapper(list(range(20)))

    def fn(e):
        return [e, e * 10]

    # Tests with buffer_size=1
    # In this case, the expected behavior is similar to flatmap

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn, buffer_size=1)
    expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp]))

    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Functional Test: Specify input_col
    tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)])

    # Single input_col
    input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1, buffer_size=1)
    self.assertEqual(expected_list, list(input_col_1_dp))

    # With generator as fn
    def gen_fn(e):
        yield e
        yield e * 10

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(gen_fn, buffer_size=1)
    expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp]))

    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Multiple input_col
    def mul_fn(a, b):
        return [a - b, b - a]

    input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2), buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp))

    # shuffled_flatmap with no fn specified
    default_dp = tuple_source_dp.shuffled_flatmap(buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, multiple input_col
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, some special input
    tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]])
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1)
    self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp))

    # Reset Test: reset the DataPipe after reading part of it
    # (``shuffled_flatmapped_dp`` is still the ``gen_fn`` pipe built with buffer_size=1)
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)

    self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_list, res_after_reset)

    # __len__ Test: length should be len(source_dp)*len(fn->out_shape) which we can't know
    with self.assertRaisesRegex(TypeError, "length relies on the output of its function."):
        len(shuffled_flatmapped_dp)

    # __len__ when no fn specified:
    dp = IterableWrapper([[1, 2], [], [3], [4, 5, 6, [7, 8]]])
    dp = dp.shuffled_flatmap()
    self.assertEqual(len(dp), 7)

    # Tests with .set_shuffle(False)
    # In this case, the expected behavior is similar to flatmap

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn).set_shuffle(False)
    expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp]))

    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Functional Test: Specify input_col
    tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)])

    # Single input_col
    # Uses the default buffer size with shuffling disabled, so this case really
    # covers ``set_shuffle(False)`` instead of repeating the buffer_size=1 case above.
    input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1).set_shuffle(False)
    self.assertEqual(expected_list, list(input_col_1_dp))

    # Multiple input_col
    input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2)).set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp))

    # shuffled_flatmap with no fn specified
    default_dp = tuple_source_dp.shuffled_flatmap().set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, multiple input_col
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, some special input
    tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]])
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False)
    self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp))

    # Reset Test: reset the DataPipe after reading part of it
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)

    self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_list, res_after_reset)

    # Other tests

    # Test no empty buffers:
    with self.assertRaises(AssertionError):
        _ = source_dp.shuffled_flatmap(buffer_size=0)

    # Functional Test: No seed
    # The order is random, so only check that every element is yielded.
    consecutive_tuple_source_dp = IterableWrapper([(d, d + 1, d + 2) for d in range(0, 21, 3)])
    shuffled_flatmapped_dp = consecutive_tuple_source_dp.shuffled_flatmap()
    self.assertEqual(set(range(21)), set(shuffled_flatmapped_dp))

    # Functional Test: With global seed
    torch.manual_seed(123)
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap()
    res = list(shuffled_flatmapped_dp)
    torch.manual_seed(123)
    self.assertEqual(list(shuffled_flatmapped_dp), res)

    # Functional Test: Set seed
    # The seed is set again before the second pass because the test expects
    # the same order only when the same seed is re-applied.
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap().set_seed(123)
    res = list(shuffled_flatmapped_dp)
    shuffled_flatmapped_dp.set_seed(123)
    self.assertEqual(list(shuffled_flatmapped_dp), res)

    # Reset Test:
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap()
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)
    self.assertEqual(5, len(res_before_reset))
8971033
def test_round_robin_demux_iterdatapipe(self):
8981034
source_dp = IterableWrapper(list(range(23)))
8991035
with self.assertRaisesRegex(ValueError, "Expected `num_instaces`"):

test/test_serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def test_serializable(self):
197197
(iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
198198
(iterdp.Enumerator, None, (2,), {}),
199199
(iterdp.FlatMapper, None, (_fake_fn_ls,), {}),
200+
(iterdp.ShuffledFlatMapper, None, (_fake_fn_ls,), {"buffer_size": 1}),
200201
(iterdp.Flattener, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}),
201202
(iterdp.FSSpecFileLister, ".", (), {}),
202203
(iterdp.FSSpecFileOpener, None, (), {}),
@@ -363,6 +364,7 @@ def test_serializable_with_dill(self):
363364
unpicklable_datapipes: List = [
364365
(iterdp.BatchMapper, (lambda batch: [d + 1 for d in batch], 2), {}),
365366
(iterdp.FlatMapper, (lambda x: [x, x],), {}),
367+
(iterdp.ShuffledFlatMapper, (lambda x: [x, x],), {"buffer_size": 1}),
366368
(iterdp.IterKeyZipper, (ref_idp, lambda x: x, None, True, 100), {}),
367369
(iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}),
368370
(iterdp.OnDiskCacheHolder, (lambda x: x,), {}),

torchdata/datapipes/iter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
DropperIterDataPipe as Dropper,
7272
FlatMapperIterDataPipe as FlatMapper,
7373
FlattenIterDataPipe as Flattener,
74+
ShuffledFlatMapperIterDataPipe as ShuffledFlatMapper,
7475
SliceIterDataPipe as Slicer,
7576
ThreadPoolMapperIterDataPipe as ThreadPoolMapper,
7677
)
@@ -209,6 +210,7 @@
209210
"ShardExpander",
210211
"ShardingFilter",
211212
"ShardingRoundRobinDispatcher",
213+
"ShuffledFlatMapper",
212214
"Shuffler",
213215
"Slicer",
214216
"StreamReader",

torchdata/datapipes/iter/transform/callable.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
import asyncio
88
import inspect
9+
import random
910
import warnings
1011
from collections import deque
1112
from concurrent import futures
1213

1314
from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
1415

16+
import torch
1517
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col
1618
from torchdata.datapipes import functional_datapipe
1719
from torchdata.datapipes.iter import IterDataPipe
@@ -170,6 +172,161 @@ def __len__(self) -> int:
170172
raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.")
171173

172174

175+
@functional_datapipe("shuffled_flatmap")
class ShuffledFlatMapperIterDataPipe(IterDataPipe):
    r"""
    Applies a function over each item from the source DataPipe,
    then collects the iterables returned in a buffer,
    then, at every iteration, chooses at random one of the iterables in the buffer
    and yields one item from this iterable (functional name: ``shuffled_flatmap``).

    When the buffer is full, the DataPipe will begin to yield elements from iterables within the buffer.
    New iterables will be added to the buffer once the existing ones run out of elements.

    Note:
        The output from ``fn`` must be an Iterable. Otherwise, an error will be raised.
        If ``fn`` is ``None``, source DataPipe will be just flattened vertically, provided that items can be unpacked.

        Shuffling can be turned off with ``set_shuffle(False)``; the DataPipe then yields the
        elements of each iterable in source order, like ``flatmap``.
        A seed given through ``set_seed`` is consumed by the next reset of the DataPipe and
        has to be set again to reproduce the same order on a later pass.

    Args:
        datapipe: Source IterDataPipe
        fn: the function to be applied to each element in the DataPipe, the output must be an Iterable
        input_col: Index or indices of data which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is/are used for list/tuple.
            - Key(s) is/are used for dict.
        buffer_size: the max number of iterables this DataPipe can hold at a time (default to ``100``)

    Raises:
        AssertionError: if ``buffer_size`` is not larger than 0.

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[1, 2, 3, 4], 'abcd', 'ABCD'])
        >>> shuffled_flatmapped_dp = source_dp.shuffled_flatmap(buffer_size=2)
        >>> list(shuffled_flatmapped_dp)
        ['a', 'b', 'c', 1, 'd', 'A', 'B', 'C', 2, 'D', 3, 4]
        >>>
        >>> # To shuffle all the elements, you can combine `shuffled_flatmap` with `in_batch_shuffle` like this:
        >>> fully_shuffled_flatmapped_dp = source_dp.in_batch_shuffle()
        >>> fully_shuffled_flatmapped_dp = fully_shuffled_flatmapped_dp.shuffled_flatmap()
        >>> list(fully_shuffled_flatmapped_dp)
        ['b', 3, 'c', 'd', 'C', 'A', 'a', 2, 'B', 'D', 4, 1]
    """
    datapipe: IterDataPipe
    fn: Optional[Callable]
    buffer_size: int
    # Iterators over the outputs of ``fn`` that are currently being drained.
    _buffer: List[Iterator]
    # Whether shuffling is active (see ``set_shuffle``).
    _enabled: bool
    # One-shot seed applied (then cleared) by ``reset``.
    _seed: Optional[int]
    _rng: random.Random
    # True when no ``fn`` was supplied and the no-op function is used instead.
    _no_op_fn: bool = False

    def __init__(
        self, datapipe: IterDataPipe, fn: Optional[Callable] = None, input_col=None, buffer_size: int = 100
    ) -> None:
        super().__init__()
        self._buffer = []
        self.datapipe = datapipe

        if fn is None:
            fn = _no_op_fn
            self._no_op_fn = True
        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]
        self.input_col = input_col
        validate_input_col(fn, input_col)

        # Raised explicitly instead of through ``assert`` so that the validation is not
        # stripped when Python runs with ``-O``; the exception type is kept unchanged.
        if not buffer_size > 0:
            raise AssertionError("buffer_size should be larger than 0")
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        """Enable or disable shuffling; returns ``self`` so calls can be chained."""
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        """Store the seed used at the next ``reset``; returns ``self`` so calls can be chained."""
        self._seed = seed
        return self

    def reset(self) -> None:
        """Drop any buffered iterators and, when shuffling is enabled, reseed the RNG."""
        self._buffer = []
        if self._enabled:
            if self._seed is None:
                # No explicit seed: draw one from torch's global RNG so that
                # ``torch.manual_seed`` makes the order reproducible.
                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self._rng.seed(self._seed)
            # The seed is consumed: a later reset draws a fresh one unless ``set_seed`` is called again.
            self._seed = None

    def _apply_fn(self, data):
        """Call ``fn`` on ``data``, or on the column(s) of ``data`` selected by ``input_col``."""
        if self.input_col is None:
            return self.fn(data)  # type: ignore[misc]
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            return self.fn(*args)  # type: ignore[misc]
        else:
            return self.fn(data[self.input_col])  # type: ignore[misc]

    def __iter__(self) -> Iterator[T_co]:
        if not self._enabled:  # equivalent to flatmap
            for x in self.datapipe:
                yield from self._apply_fn(x)
        else:
            idx = self._rng.randint(0, self.buffer_size - 1)
            for x in self.datapipe:
                # While the buffer is full, emit elements from randomly chosen iterators
                # until one of them is exhausted and frees a slot for the new item.
                while len(self._buffer) == self.buffer_size:
                    try:
                        yield next(self._buffer[idx])
                        idx = self._rng.randint(0, self.buffer_size - 1)
                    except StopIteration:
                        self._buffer.pop(idx)
                self._buffer.append(iter(self._apply_fn(x)))
            # Source exhausted: drain the iterators that are still buffered.
            while self._buffer:
                try:
                    idx = self._rng.randint(0, len(self._buffer) - 1)
                    yield next(self._buffer[idx])
                except StopIteration:
                    self._buffer.pop(idx)

    def __len__(self) -> int:
        """Return the number of elements yielded when no ``fn`` was supplied.

        The length is computed by iterating over the source DataPipe. It is measured on
        what ``_apply_fn`` returns (the same objects ``__iter__`` flattens) so that
        ``input_col`` is taken into account.

        Raises:
            TypeError: if a ``fn`` was supplied, since the length then depends on its output.
        """
        if self._no_op_fn:
            return sum(len(self._apply_fn(x)) for x in self.datapipe)
        raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.")

    def __getstate__(self):
        state = (
            self.datapipe,
            self.fn,
            self.input_col,
            self.buffer_size,
            self._buffer,
            self._enabled,
            self._seed,
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.fn,
            self.input_col,
            self.buffer_size,
            self._buffer,
            self._enabled,
            self._seed,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)
        # ``_no_op_fn`` is not part of the serialized tuple; rebuild it from ``fn`` so that
        # ``len()`` keeps working on a deserialized DataPipe that was created without ``fn``.
        self._no_op_fn = self.fn is _no_op_fn

    def __del__(self):
        self._buffer.clear()
329+
173330
@functional_datapipe("drop")
174331
class DropperIterDataPipe(IterDataPipe[T_co]):
175332
r"""

0 commit comments

Comments
 (0)