Skip to content

Commit 01fc762

Browse files
ejguan authored and facebook-github-bot committed
Fix lint for SequentialReadingService and unify the API for process_reset_fn (meta-pytorch#983)
Summary: This is the follow up for 807db8f ### Changes - Fix lint - Fix checkpoint result - Add inline doc - Unify the API for process_reset_fn and dispatch_process_reset_fn After this PR, I will land a test case and tutorial for SequentialRS Pull Request resolved: meta-pytorch#983 Reviewed By: NivekT Differential Revision: D42979577 Pulled By: ejguan fbshipit-source-id: 1cb601c20701cf44a1f7b3afe972a9436cdb9465
1 parent 807db8f commit 01fc762

4 files changed

Lines changed: 51 additions & 16 deletions

File tree

docs/source/dataloader2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ ReadingService
3333
DistributedReadingService
3434
MultiProcessingReadingService
3535
PrototypeMultiProcessingReadingService
36+
SequentialReadingService
3637

3738
Each ``ReadingService`` would take the ``DataPipe`` graph and rewrite it to achieve a few features like dynamic sharding, sharing random seeds and snapshotting for multi-/distributed processes. For more detail about those features, please refer to `the documentation <reading_service.html>`_.
3839

torchdata/dataloader2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
MultiProcessingReadingService,
1414
PrototypeMultiProcessingReadingService,
1515
ReadingServiceInterface,
16-
SequentialReadingService
16+
SequentialReadingService,
1717
)
1818
from torchdata.dataloader2.shuffle_spec import ShuffleSpec
1919

torchdata/dataloader2/reading_service.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def finalize(self) -> None:
6464
"""
6565
pass
6666

67-
def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> Optional[Callable[[DataPipe], DataPipe]]:
67+
def initialize_iteration(
68+
self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
69+
) -> Optional[Callable[[DataPipe], DataPipe]]:
6870
r"""
6971
``ReadingService`` spins up service for an epoch. Called at the beginning
7072
of every time getting ``DataLoader2`` iterator.
@@ -73,6 +75,11 @@ def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Opt
7375
seed_generator: SeedGenerator object created and managed by DataLoader2. As the single
7476
source of randomness, it governs the determinism of all random operations
7577
with the graph of DataPipes.
78+
iter_reset_fn: Optional reset function from the prior ``ReadingService``
79+
when ``SequentialReadingService`` chains multiple ``ReadingServices``
80+
81+
Returns:
82+
A new ``iter_reset_fn`` to be used by subsequent ``ReadingService``
7683
7784
Example:
7885
MultiProcessingReadingService starts setting worker seeds per process and prefetching
@@ -303,7 +310,9 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
303310

304311
return self._end_datapipe # type: ignore[return-value]
305312

306-
def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> Optional[Callable[[DataPipe], DataPipe]]:
313+
def initialize_iteration(
314+
self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
315+
) -> Optional[Callable[[DataPipe], DataPipe]]:
307316
if self._pg is not None:
308317
shared_seed_int = dist_share_seed(seed_generator.generate_shared_seed(), self._pg)
309318
seed_generator.seed(shared_seed_int)
@@ -318,7 +327,9 @@ def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Opt
318327
# Stop prefetching first
319328
self._main_prefetch_datapipe.reset() # type: ignore[union-attr]
320329
# Send the shared seed to subprocesses
321-
call_on_epoch_reset = partial(process_reset_fn, custom_reset_fn=self.worker_reset_fn, custom_dispatch_process_reset_fn=iter_reset_fn)
330+
call_on_epoch_reset = partial(
331+
process_reset_fn, iter_reset_fn=iter_reset_fn, custom_reset_fn=self.worker_reset_fn
332+
)
322333
assert self._worker_consumer_datapipe is not None
323334
self._worker_consumer_datapipe.reset_epoch(call_on_epoch_reset, seed_generator)
324335
# In-process (num_workers == 0)
@@ -328,6 +339,7 @@ def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Opt
328339
# (random, torch and numpy), if users have already seeded them in the main process
329340
# TODO(ejguan): This should be fixed by adding a method to isolate global RNGs
330341
pass
342+
return None
331343

332344
def __del__(self):
333345
self.finalize()
@@ -485,7 +497,9 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
485497
self._datapipe = datapipe
486498
return datapipe
487499

488-
def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
500+
def initialize_iteration(
501+
self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
502+
) -> Optional[Callable[[DataPipe], DataPipe]]:
489503
r"""
490504
Shares the same seed from rank 0 to other ranks across the distributed processes
491505
and apply the random seed to the ``DataPipe`` graph.
@@ -496,6 +510,7 @@ def initialize_iteration(self, seed_generator: SeedGenerator) -> None:
496510
seed_generator.seed(shared_seed)
497511
seed_generator = seed_generator.spawn(self._rank, inplace=True)
498512
set_graph_random_seed(self._datapipe, seed_generator)
513+
return None
499514

500515
def __del__(self):
501516
self.finalize()
@@ -509,7 +524,7 @@ def finalize(self) -> None:
509524
self._pg = None
510525

511526

512-
class SequentialReadingService(ReadingServiceInterface):
527+
class SequentialReadingService(CheckpointableReadingServiceInterface):
513528
def __init__(self, *reading_services):
514529
self.reading_services = reading_services
515530

@@ -525,10 +540,15 @@ def finalize(self) -> None:
525540
rs.finalize()
526541

527542
# Sequential Order
528-
def initialize_iteration(self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> Optional[Callable[[DataPipe], DataPipe]]:
543+
def initialize_iteration(
544+
self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None
545+
) -> Optional[Callable[[DataPipe], DataPipe]]:
529546
chained_iter_reset_fn = iter_reset_fn
530547
for rs in self.reading_services:
531-
chained_iter_reset_fn = rs.initialize_iteration(seed_generator=seed_generator, iter_reset_fn=chained_iter_reset_fn)
548+
chained_iter_reset_fn = rs.initialize_iteration(
549+
seed_generator=seed_generator, iter_reset_fn=chained_iter_reset_fn
550+
)
551+
return None
532552

533553
# Reversed Order
534554
def finalize_iteration(self) -> None:
@@ -540,9 +560,12 @@ def checkpoint(self) -> bytes:
540560
states = []
541561
for rs in self.reading_services:
542562
states.append(rs.checkpoint())
563+
return b"\n".join(states)
543564

544565
# Sequential Order, to align with initialize
545-
def restore(self, datapipe, serialized_state) -> DataPipe:
546-
for rs, state in zip(self.reading_service, serialized_state):
566+
def restore(self, datapipe, serialized_state: bytes) -> DataPipe:
567+
states = serialized_state.split(b"\n")
568+
assert len(states) == len(self.reading_services)
569+
for rs, state in zip(self.reading_services, states):
547570
datapipe = rs.restore(datapipe, state)
548571
return datapipe

torchdata/dataloader2/utils/worker.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def dispatch_process_reset_fn(
116116
datapipe: DataPipe,
117117
worker_info: WorkerInfo,
118118
seed_generator: SeedGenerator,
119-
custom_dispatch_process_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None,
119+
iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None,
120+
custom_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None,
120121
) -> DataPipe:
121122
r"""
122123
Based on the distributed shared random seed, this function is used to set the random state
@@ -131,8 +132,12 @@ def dispatch_process_reset_fn(
131132
dps = list_dps(graph)
132133
set_datapipes_seed(dps, seed_generator=seed_generator, distributed_shared=True)
133134

134-
if custom_dispatch_process_reset_fn is not None:
135-
datapipe = custom_dispatch_process_reset_fn(datapipe)
135+
if iter_reset_fn is not None:
136+
datapipe = iter_reset_fn(datapipe)
137+
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
138+
139+
if custom_reset_fn is not None:
140+
datapipe = custom_reset_fn(datapipe, worker_info, seed_generator)
136141
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
137142

138143
return datapipe
@@ -142,8 +147,8 @@ def process_reset_fn(
142147
datapipe: DataPipe,
143148
worker_info: WorkerInfo,
144149
seed_generator: SeedGenerator,
145-
custom_reset_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
146-
custom_dispatch_process_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None,
150+
iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None,
151+
custom_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None,
147152
) -> DataPipe:
148153
r"""
149154
Based on the distributed shared random seed and worker id, this function is used to
@@ -160,14 +165,20 @@ def process_reset_fn(
160165
# Only send the reset epoch message once
161166
if worker_info.worker_id == 0:
162167
# Use WorkerInfo(1, 0)
163-
dispatch_reset_fn = partial(dispatch_process_reset_fn, custom_dispatch_process_reset_fn=custom_dispatch_process_reset_fn)
168+
dispatch_reset_fn = partial(
169+
dispatch_process_reset_fn, iter_reset_fn=iter_reset_fn, custom_reset_fn=custom_reset_fn
170+
)
164171
dispatch_process_consumer_dp.reset_epoch(dispatch_reset_fn, seed_generator)
165172

166173
# Set global random states
167174
_set_global_random_state(seed_generator)
168175

169176
set_graph_random_seed(datapipe, seed_generator)
170177

178+
if iter_reset_fn is not None:
179+
datapipe = iter_reset_fn(datapipe)
180+
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
181+
171182
if custom_reset_fn is not None:
172183
datapipe = custom_reset_fn(datapipe, worker_info, seed_generator)
173184
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))

0 commit comments

Comments
 (0)