Skip to content

Commit 2ca1fa6

Browse files
ejguanfacebook-github-bot
authored andcommitted
Add support to keep non-replicable DataPipe in the main process (meta-pytorch#950)
Summary: This is useful when `fullsync` is in the pipeline and we don't want this DataPipe to run in a worker process. ### Changes - Change the dispatching-related function names to `dispatching_xxx` - Make the `fullsync` DataPipe non-replicable - Add `_find_replicable_branches` to find the last DataPipe prior to any non-replicable DataPipe - Add graph tests - In `PrototypeMultiProcessingReadingService`, make sure only the `replicable_datapipe` is sent to the worker processes, and replace the `replicable_datapipe` with the `worker_consumer_datapipe`. Pull Request resolved: meta-pytorch#950 Reviewed By: wenleix, NivekT, Miiira Differential Revision: D42617776 Pulled By: ejguan fbshipit-source-id: 1138203507934b089025e290597b473ef9be32bb
1 parent 29d7a5e commit 2ca1fa6

8 files changed

Lines changed: 324 additions & 104 deletions

File tree

docs/source/dataloader2.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Dynamic sharding is achieved by ``PrototypeMultiProcessingReadingService`` and `
4141

4242
- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from one replica of the ``DataPipe`` graph, and skips the data that does not belong to the corresponding worker at the place of ``sharding_filter``.
4343

44-
- ``sharding_round_robin_dispatch``: When there is any non-replicable ``DataPipe`` (``sharding_round_robin_dispatch``) in the pipeline, a dispatching process will be created to load data from the non-replicable ``DataPipe`` and distributed data to the subsequent worker processes.
44+
- ``sharding_round_robin_dispatch``: When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch will be treated as a non-replicable branch. Then, a single dispatching process will be created to load data from the non-replicable branch and distribute data to the subsequent worker processes.
4545

4646
The following is an example of having two types of sharding strategies in the pipeline.
4747

test/dataloader2/test_dataloader2.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,17 @@ def _x_mult_2(d):
360360
return d * 2
361361

362362

363+
class NonReplicableDataPipe(IterDataPipe):
364+
def __init__(self, datapipe):
365+
self.datapipe = datapipe
366+
367+
def __iter__(self):
368+
yield from self.datapipe
369+
370+
def is_replicable(self):
371+
return False
372+
373+
363374
class PrototypeMultiProcessingReadingServiceTest(TestCase):
364375
@staticmethod
365376
def _worker_init_fn(datapipe, worker_info):
@@ -592,6 +603,37 @@ def test_dispatching_worker_determinism(self, ctx):
592603
dl.seed(321)
593604
self.assertNotEqual(res, list(dl) + list(dl))
594605

606+
@mp_ctx_parametrize
607+
def test_non_replicable_datapipe(self, ctx) -> None:
608+
r"""
609+
For the pipeline with non-replicable DataPipe, make sure
610+
the DataPipe remains in the main process.
611+
"""
612+
dp: IterDataPipe = IterableWrapper(range(100))
613+
dp = dp.shuffle().sharding_filter()
614+
dp = dp.batch(2)
615+
non_rep_dp = NonReplicableDataPipe(dp)
616+
617+
rs = PrototypeMultiProcessingReadingService(
618+
num_workers=2,
619+
multiprocessing_context=ctx,
620+
)
621+
dl = DataLoader2(non_rep_dp, reading_service=rs)
622+
623+
torch.manual_seed(123)
624+
it = iter(dl)
625+
# Validate NonReplicableDataPipe still in the main process
626+
non_rep_dp = dl.reading_service._end_datapipe._datapipe
627+
self.assertEqual(type(non_rep_dp), NonReplicableDataPipe)
628+
629+
res = list(it) + list(dl)
630+
631+
torch.manual_seed(123)
632+
self.assertEqual(res, list(dl) + list(dl))
633+
634+
torch.manual_seed(321)
635+
self.assertNotEqual(res, list(dl) + list(dl))
636+
595637

596638
instantiate_parametrized_tests(PrototypeMultiProcessingReadingServiceTest)
597639

test/test_graph.py

Lines changed: 153 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import types
78
import unittest
89

910
from typing import Dict, Iterator, List, Tuple, TypeVar
@@ -17,11 +18,12 @@
1718

1819
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, ReadingServiceInterface
1920
from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps
21+
from torchdata.dataloader2.graph.utils import _find_replicable_branches
2022
from torchdata.dataloader2.random import SeedGenerator
2123
from torchdata.dataloader2.utils.dispatch import (
2224
_DummyIterDataPipe,
23-
find_lca_non_replicable_dp,
24-
find_replicable_branches,
25+
find_lca_round_robin_sharding_dp,
26+
find_non_dispatching_branches,
2527
)
2628
from torchdata.datapipes.iter import IterableWrapper, Mapper, ShardingRoundRobinDispatcher
2729
from torchdata.datapipes.utils import to_graph
@@ -262,15 +264,20 @@ def test_multiprocessing_reading_service(self) -> None:
262264
self.assertEqual(d1, d2)
263265

264266

265-
def make_dp_non_replicable(graph, datapipe):
266-
non_rep_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)
267-
return replace_dp(graph, datapipe, non_rep_dp), non_rep_dp
267+
def insert_round_robin_sharding(graph, datapipe):
268+
dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)
269+
return replace_dp(graph, datapipe, dispatch_dp), dispatch_dp
268270

269271

270272
def replace_by_dummy(graph, datapipe):
271273
return replace_dp(graph, datapipe, _DummyIterDataPipe())
272274

273275

276+
def make_non_replicable_dp(datapipe):
277+
datapipe.is_replicable = types.MethodType(lambda self: False, datapipe)
278+
return datapipe
279+
280+
274281
class TestNonReplicableDataPipe(expecttest.TestCase):
275282
def _make_dp(self):
276283
r"""
@@ -279,21 +286,22 @@ def _make_dp(self):
279286
- multi-branch pipeline
280287
- pipeline that has circular references
281288
282-
single_br_dp ---------------------------------
283-
ch1 \
284-
/ \ \
285-
multi_br_dp ---------> -> fork_zip_dp -> end_dp ->
286-
\ / /
287-
<------- ch2 /
288-
/ \ /
289-
cir_br_dp -> cir_map_dp ----------------------
289+
single_br_dp -------------------------------------
290+
ch1 \
291+
/ \ \
292+
multi_br_dp -->forker_dp--> -> fork_zip_dp -> end_dp ->
293+
\ / /
294+
<------- ch2 /
295+
/ \ /
296+
cir_br_dp -> cir_map_dp --------------------------
290297
"""
291298
# Single-branch
292299
single_br_dp = IterableWrapper(list(range(10)))
293300

294301
# Multi-branch
295302
multi_br_dp = IterableWrapper(list(range(10)))
296303
ch1, ch2 = multi_br_dp.fork(2)
304+
forker_dp = ch1.main_datapipe
297305
fork_zip_dp = ch1.zip(ch2)
298306

299307
# Circular-branch
@@ -304,93 +312,188 @@ def _make_dp(self):
304312

305313
end_dp = single_br_dp.zip(fork_zip_dp, cir_map_dp)
306314
graph = traverse_dps(end_dp)
307-
return single_br_dp, multi_br_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph
315+
return single_br_dp, multi_br_dp, forker_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph
308316

309-
def test_single_non_replicable_dp(self):
317+
def test_single_round_robin_sharding_dp(self):
310318
single_br_dp, *_, graph = self._make_dp()
311-
graph, single_br_dp = make_dp_non_replicable(graph, single_br_dp)
312-
self.assertEqual(find_lca_non_replicable_dp(graph), single_br_dp)
319+
graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp)
320+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), single_br_dp)
313321

314322
# The same non-shardable DataPipe on both branches
315323
_, multi_br_dp, *_, graph = self._make_dp()
316-
graph, multi_br_dp = make_dp_non_replicable(graph, multi_br_dp)
317-
self.assertEqual(find_lca_non_replicable_dp(graph), multi_br_dp)
324+
graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp)
325+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), multi_br_dp)
318326

319-
_, _, ch1, _, fork_zip_dp, *_, graph = self._make_dp()
320-
graph, ch1 = make_dp_non_replicable(graph, ch1)
321-
self.assertEqual(find_lca_non_replicable_dp(graph), fork_zip_dp)
327+
_, _, _, ch1, _, fork_zip_dp, *_, graph = self._make_dp()
328+
graph, ch1 = insert_round_robin_sharding(graph, ch1)
329+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), fork_zip_dp)
322330

323331
# Circular reference
324332
*_, cir_br_dp, cir_map_dp, _, graph = self._make_dp()
325-
graph, cir_br_dp = make_dp_non_replicable(graph, cir_br_dp)
326-
self.assertEqual(find_lca_non_replicable_dp(graph), cir_map_dp)
333+
graph, cir_br_dp = insert_round_robin_sharding(graph, cir_br_dp)
334+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), cir_map_dp)
327335

328336
*_, cir_map_dp, _, graph = self._make_dp()
329-
graph, cir_map_dp = make_dp_non_replicable(graph, cir_map_dp)
330-
self.assertEqual(find_lca_non_replicable_dp(graph), cir_map_dp)
337+
graph, cir_map_dp = insert_round_robin_sharding(graph, cir_map_dp)
338+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), cir_map_dp)
331339

332-
def test_multi_non_replicable_dps(self):
340+
def test_multi_round_robin_sharding_dps(self):
333341
single_br_dp, multi_br_dp, *_, end_dp, graph = self._make_dp()
334-
graph, single_br_dp = make_dp_non_replicable(graph, single_br_dp)
335-
graph, multi_br_dp = make_dp_non_replicable(graph, multi_br_dp)
336-
self.assertEqual(find_lca_non_replicable_dp(graph), end_dp)
342+
graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp)
343+
graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp)
344+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp)
337345

338-
single_br_dp, _, ch1, *_, end_dp, graph = self._make_dp()
339-
graph, single_br_dp = make_dp_non_replicable(graph, single_br_dp)
340-
graph, ch1 = make_dp_non_replicable(graph, ch1)
341-
self.assertEqual(find_lca_non_replicable_dp(graph), end_dp)
346+
single_br_dp, _, _, ch1, *_, end_dp, graph = self._make_dp()
347+
graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp)
348+
graph, ch1 = insert_round_robin_sharding(graph, ch1)
349+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp)
342350

343-
_, multi_br_dp, ch1, _, fork_zip_dp, *_, graph = self._make_dp()
344-
graph, multi_br_dp = make_dp_non_replicable(graph, multi_br_dp)
345-
graph, ch1 = make_dp_non_replicable(graph, ch1)
346-
self.assertEqual(find_lca_non_replicable_dp(graph), fork_zip_dp)
351+
_, multi_br_dp, _, ch1, _, fork_zip_dp, *_, graph = self._make_dp()
352+
graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp)
353+
graph, ch1 = insert_round_robin_sharding(graph, ch1)
354+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), fork_zip_dp)
347355

348356
single_br_dp, *_, cir_br_dp, _, end_dp, graph = self._make_dp()
349-
graph, single_br_dp = make_dp_non_replicable(graph, single_br_dp)
350-
graph, cir_br_dp = make_dp_non_replicable(graph, cir_br_dp)
351-
self.assertEqual(find_lca_non_replicable_dp(graph), end_dp)
357+
graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp)
358+
graph, cir_br_dp = insert_round_robin_sharding(graph, cir_br_dp)
359+
self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp)
352360

353-
def test_replicable_branches(self):
361+
def test_non_dispatching_branches(self):
354362
r"""
355363
There should be a single DataPipe as the lowest common ancestor of all
356-
non-replicable DataPipes that is replaced by ``DummyIterDataPipe``.
364+
non-dispatching DataPipes that is replaced by ``DummyIterDataPipe``.
357365
"""
358366
single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
359367
graph = replace_by_dummy(graph, single_br_dp)
360-
dps = find_replicable_branches(graph)
368+
dps = find_non_dispatching_branches(graph)
361369
self.assertEqual(len(dps), 2)
362370
self.assertTrue(all(dp in (fork_zip_dp, cir_map_dp) for dp in dps))
363371

364372
single_br_dp, multi_br_dp, *_, cir_map_dp, _, graph = self._make_dp()
365373
graph = replace_by_dummy(graph, multi_br_dp)
366-
dps = find_replicable_branches(graph)
374+
dps = find_non_dispatching_branches(graph)
367375
self.assertEqual(len(dps), 2)
368376
self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps))
369377

370378
# In theory, this case should never happen because LCA (fork_zip_dp) should be
371379
# replaced by _DummyIterDataPipe if any of its children is non-replicable
372-
single_br_dp, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp()
380+
single_br_dp, _, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp()
373381
graph = replace_by_dummy(graph, ch1)
374-
dps = find_replicable_branches(graph)
382+
dps = find_non_dispatching_branches(graph)
375383
self.assertEqual(len(dps), 3)
376384
self.assertTrue(all(dp in (single_br_dp, ch2, cir_map_dp) for dp in dps))
377385

378386
single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
379387
graph = replace_by_dummy(graph, cir_map_dp)
380-
dps = find_replicable_branches(graph)
388+
dps = find_non_dispatching_branches(graph)
381389
self.assertTrue(all(dp in (single_br_dp, fork_zip_dp) for dp in dps))
382390

383391
*_, end_dp, graph = self._make_dp()
384392
graph = replace_by_dummy(graph, end_dp)
385-
dps = find_replicable_branches(graph)
393+
dps = find_non_dispatching_branches(graph)
386394
self.assertEqual(len(dps), 0)
387395

388396
single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
389397
graph = replace_by_dummy(graph, fork_zip_dp)
390-
dps = find_replicable_branches(graph)
398+
dps = find_non_dispatching_branches(graph)
391399
self.assertEqual(len(dps), 2)
392400
self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps))
393401

402+
def test_single_non_replicable_dp(self):
403+
# All replicable
404+
*_, end_dp, graph = self._make_dp()
405+
dps = _find_replicable_branches(graph)
406+
self.assertEqual(len(dps), 1)
407+
self.assertEqual(dps[0], end_dp)
408+
409+
# Test the production use case where the last DataPipe is fullsync
410+
*_, end_dp, _ = self._make_dp()
411+
dp = end_dp.fullsync()
412+
graph = traverse_dps(dp)
413+
dps = _find_replicable_branches(graph)
414+
self.assertEqual(len(dps), 1)
415+
self.assertEqual(dps[0], end_dp)
416+
417+
single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
418+
make_non_replicable_dp(single_br_dp)
419+
dps = _find_replicable_branches(graph)
420+
self.assertEqual(len(dps), 2)
421+
self.assertTrue(all(dp in (fork_zip_dp, cir_map_dp) for dp in dps))
422+
423+
single_br_dp, *_, ch1, ch2, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
424+
make_non_replicable_dp(fork_zip_dp)
425+
dps = _find_replicable_branches(graph)
426+
self.assertEqual(len(dps), 4)
427+
self.assertTrue(all(dp in (single_br_dp, ch1, ch2, cir_map_dp) for dp in dps))
428+
429+
single_br_dp, _, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp()
430+
make_non_replicable_dp(ch1)
431+
dps = _find_replicable_branches(graph)
432+
self.assertEqual(len(dps), 3)
433+
self.assertTrue(all(dp in (single_br_dp, forker_dp, cir_map_dp) for dp in dps))
434+
435+
single_br_dp, *_, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp()
436+
make_non_replicable_dp(cir_map_dp)
437+
dps = _find_replicable_branches(graph)
438+
self.assertEqual(len(dps), 3)
439+
self.assertTrue(all(dp in (single_br_dp, fork_zip_dp, cir_br_dp) for dp in dps))
440+
441+
single_br_dp, *_, fork_zip_dp, _, cir_map_dp, end_dp, graph = self._make_dp()
442+
make_non_replicable_dp(end_dp)
443+
dps = _find_replicable_branches(graph)
444+
self.assertEqual(len(dps), 3)
445+
self.assertTrue(all(dp in (single_br_dp, fork_zip_dp, cir_map_dp) for dp in dps))
446+
447+
def test_multi_non_replicable_dps(self):
448+
single_br_dp, multi_br_dp, *_, cir_map_dp, _, graph = self._make_dp()
449+
make_non_replicable_dp(single_br_dp)
450+
make_non_replicable_dp(multi_br_dp)
451+
dps = _find_replicable_branches(graph)
452+
self.assertEqual(len(dps), 1)
453+
self.assertEqual(dps[0], cir_map_dp)
454+
455+
single_br_dp, _, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp()
456+
make_non_replicable_dp(single_br_dp)
457+
make_non_replicable_dp(ch1)
458+
dps = _find_replicable_branches(graph)
459+
self.assertEqual(len(dps), 2)
460+
self.assertTrue(all(dp in (forker_dp, cir_map_dp) for dp in dps))
461+
462+
single_br_dp, *_, ch1, ch2, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp()
463+
make_non_replicable_dp(single_br_dp)
464+
make_non_replicable_dp(fork_zip_dp)
465+
dps = _find_replicable_branches(graph)
466+
self.assertEqual(len(dps), 3)
467+
self.assertTrue(all(dp in (ch1, ch2, cir_map_dp) for dp in dps))
468+
469+
single_br_dp, *_, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp()
470+
make_non_replicable_dp(single_br_dp)
471+
make_non_replicable_dp(cir_map_dp)
472+
dps = _find_replicable_branches(graph)
473+
self.assertEqual(len(dps), 2)
474+
self.assertTrue(all(dp in (fork_zip_dp, cir_br_dp) for dp in dps))
475+
476+
single_br_dp, multi_br_dp, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp()
477+
make_non_replicable_dp(forker_dp)
478+
make_non_replicable_dp(ch1)
479+
dps = _find_replicable_branches(graph)
480+
self.assertEqual(len(dps), 3)
481+
self.assertTrue(all(dp in (single_br_dp, multi_br_dp, cir_map_dp) for dp in dps))
482+
483+
single_br_dp, multi_br_dp, forker_dp, *_, cir_br_dp, cir_map_dp, _, graph = self._make_dp()
484+
make_non_replicable_dp(forker_dp)
485+
make_non_replicable_dp(cir_map_dp)
486+
dps = _find_replicable_branches(graph)
487+
self.assertEqual(len(dps), 3)
488+
self.assertTrue(all(dp in (single_br_dp, multi_br_dp, cir_br_dp) for dp in dps))
489+
490+
single_br_dp, *_, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp()
491+
make_non_replicable_dp(fork_zip_dp)
492+
make_non_replicable_dp(cir_map_dp)
493+
dps = _find_replicable_branches(graph)
494+
self.assertEqual(len(dps), 4)
495+
self.assertTrue(all(dp in (single_br_dp, ch1, ch2, cir_br_dp) for dp in dps))
496+
394497

395498
class TestGraphVisualization(expecttest.TestCase):
396499
@unittest.skipIf(not HAS_GRAPHVIZ, "Package `graphviz` is required to test graph visualization functionalities.")

0 commit comments

Comments
 (0)