44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import types
78import unittest
89
910from typing import Dict , Iterator , List , Tuple , TypeVar
1718
1819from torchdata .dataloader2 import DataLoader2 , MultiProcessingReadingService , ReadingServiceInterface
1920from torchdata .dataloader2 .graph import find_dps , list_dps , remove_dp , replace_dp , traverse_dps
21+ from torchdata .dataloader2 .graph .utils import _find_replicable_branches
2022from torchdata .dataloader2 .random import SeedGenerator
2123from torchdata .dataloader2 .utils .dispatch import (
2224 _DummyIterDataPipe ,
23- find_lca_non_replicable_dp ,
24- find_replicable_branches ,
25+ find_lca_round_robin_sharding_dp ,
26+ find_non_dispatching_branches ,
2527)
2628from torchdata .datapipes .iter import IterableWrapper , Mapper , ShardingRoundRobinDispatcher
2729from torchdata .datapipes .utils import to_graph
@@ -262,15 +264,20 @@ def test_multiprocessing_reading_service(self) -> None:
262264 self .assertEqual (d1 , d2 )
263265
264266
def insert_round_robin_sharding(graph, datapipe):
    """Wrap ``datapipe`` in a ``ShardingRoundRobinDispatcher`` and splice it into ``graph``.

    Args:
        graph: DataPipe graph that contains ``datapipe``.
        datapipe: The DataPipe to wrap with the round-robin dispatcher.

    Returns:
        A ``(new_graph, dispatch_dp)`` tuple, where ``new_graph`` is the result
        of ``replace_dp`` and ``dispatch_dp`` is the newly created dispatcher.
    """
    dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING)
    return replace_dp(graph, datapipe, dispatch_dp), dispatch_dp
268270
269271
def replace_by_dummy(graph, datapipe):
    """Return the graph produced by swapping ``datapipe`` for a fresh ``_DummyIterDataPipe``.

    Args:
        graph: DataPipe graph that contains ``datapipe``.
        datapipe: The DataPipe to replace.

    Returns:
        Whatever ``replace_dp`` returns for the substitution.
    """
    return replace_dp(graph, datapipe, _DummyIterDataPipe())
272274
273275
def make_non_replicable_dp(datapipe):
    """Patch ``datapipe`` in place so that ``datapipe.is_replicable()`` returns ``False``.

    Only the given instance is affected: the override is bound to the instance
    itself, so other instances of the same class keep their own behavior.

    Args:
        datapipe: The object to patch.

    Returns:
        The same ``datapipe`` object, for convenience.
    """
    # Bind the override as an instance-level method rather than touching the class.
    datapipe.is_replicable = types.MethodType(lambda self: False, datapipe)
    return datapipe
279+
280+
274281class TestNonReplicableDataPipe (expecttest .TestCase ):
275282 def _make_dp (self ):
276283 r"""
@@ -279,21 +286,22 @@ def _make_dp(self):
279286 - multi-branch pipeline
280287 - pipeline that has circular references
281288
282- single_br_dp ---------------------------------
283- ch1 \
284- / \ \
285- multi_br_dp ------- --> -> fork_zip_dp -> end_dp ->
286- \ / /
287- <------- ch2 /
288- / \ /
289- cir_br_dp -> cir_map_dp ----------------------
289+ single_br_dp -------------------------------------
290+ ch1 \
291+ / \ \
292+ multi_br_dp -->forker_dp --> -> fork_zip_dp -> end_dp ->
293+ \ / /
294+ <------- ch2 /
295+ / \ /
296+ cir_br_dp -> cir_map_dp --------------------------
290297 """
291298 # Single-branch
292299 single_br_dp = IterableWrapper (list (range (10 )))
293300
294301 # Multi-branch
295302 multi_br_dp = IterableWrapper (list (range (10 )))
296303 ch1 , ch2 = multi_br_dp .fork (2 )
304+ forker_dp = ch1 .main_datapipe
297305 fork_zip_dp = ch1 .zip (ch2 )
298306
299307 # Circular-branch
@@ -304,93 +312,188 @@ def _make_dp(self):
304312
305313 end_dp = single_br_dp .zip (fork_zip_dp , cir_map_dp )
306314 graph = traverse_dps (end_dp )
307- return single_br_dp , multi_br_dp , ch1 , ch2 , fork_zip_dp , cir_br_dp , cir_map_dp , end_dp , graph
315+ return single_br_dp , multi_br_dp , forker_dp , ch1 , ch2 , fork_zip_dp , cir_br_dp , cir_map_dp , end_dp , graph
308316
309- def test_single_non_replicable_dp (self ):
317+ def test_single_round_robin_sharding_dp (self ):
310318 single_br_dp , * _ , graph = self ._make_dp ()
311- graph , single_br_dp = make_dp_non_replicable (graph , single_br_dp )
312- self .assertEqual (find_lca_non_replicable_dp (graph ), single_br_dp )
319+ graph , single_br_dp = insert_round_robin_sharding (graph , single_br_dp )
320+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), single_br_dp )
313321
314322 # The same non-shardable DataPipe on both branches
315323 _ , multi_br_dp , * _ , graph = self ._make_dp ()
316- graph , multi_br_dp = make_dp_non_replicable (graph , multi_br_dp )
317- self .assertEqual (find_lca_non_replicable_dp (graph ), multi_br_dp )
324+ graph , multi_br_dp = insert_round_robin_sharding (graph , multi_br_dp )
325+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), multi_br_dp )
318326
319- _ , _ , ch1 , _ , fork_zip_dp , * _ , graph = self ._make_dp ()
320- graph , ch1 = make_dp_non_replicable (graph , ch1 )
321- self .assertEqual (find_lca_non_replicable_dp (graph ), fork_zip_dp )
327+ _ , _ , _ , ch1 , _ , fork_zip_dp , * _ , graph = self ._make_dp ()
328+ graph , ch1 = insert_round_robin_sharding (graph , ch1 )
329+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), fork_zip_dp )
322330
323331 # Circular reference
324332 * _ , cir_br_dp , cir_map_dp , _ , graph = self ._make_dp ()
325- graph , cir_br_dp = make_dp_non_replicable (graph , cir_br_dp )
326- self .assertEqual (find_lca_non_replicable_dp (graph ), cir_map_dp )
333+ graph , cir_br_dp = insert_round_robin_sharding (graph , cir_br_dp )
334+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), cir_map_dp )
327335
328336 * _ , cir_map_dp , _ , graph = self ._make_dp ()
329- graph , cir_map_dp = make_dp_non_replicable (graph , cir_map_dp )
330- self .assertEqual (find_lca_non_replicable_dp (graph ), cir_map_dp )
337+ graph , cir_map_dp = insert_round_robin_sharding (graph , cir_map_dp )
338+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), cir_map_dp )
331339
332- def test_multi_non_replicable_dps (self ):
340+ def test_multi_round_robin_sharding_dps (self ):
333341 single_br_dp , multi_br_dp , * _ , end_dp , graph = self ._make_dp ()
334- graph , single_br_dp = make_dp_non_replicable (graph , single_br_dp )
335- graph , multi_br_dp = make_dp_non_replicable (graph , multi_br_dp )
336- self .assertEqual (find_lca_non_replicable_dp (graph ), end_dp )
342+ graph , single_br_dp = insert_round_robin_sharding (graph , single_br_dp )
343+ graph , multi_br_dp = insert_round_robin_sharding (graph , multi_br_dp )
344+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), end_dp )
337345
338- single_br_dp , _ , ch1 , * _ , end_dp , graph = self ._make_dp ()
339- graph , single_br_dp = make_dp_non_replicable (graph , single_br_dp )
340- graph , ch1 = make_dp_non_replicable (graph , ch1 )
341- self .assertEqual (find_lca_non_replicable_dp (graph ), end_dp )
346+ single_br_dp , _ , _ , ch1 , * _ , end_dp , graph = self ._make_dp ()
347+ graph , single_br_dp = insert_round_robin_sharding (graph , single_br_dp )
348+ graph , ch1 = insert_round_robin_sharding (graph , ch1 )
349+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), end_dp )
342350
343- _ , multi_br_dp , ch1 , _ , fork_zip_dp , * _ , graph = self ._make_dp ()
344- graph , multi_br_dp = make_dp_non_replicable (graph , multi_br_dp )
345- graph , ch1 = make_dp_non_replicable (graph , ch1 )
346- self .assertEqual (find_lca_non_replicable_dp (graph ), fork_zip_dp )
351+ _ , multi_br_dp , _ , ch1 , _ , fork_zip_dp , * _ , graph = self ._make_dp ()
352+ graph , multi_br_dp = insert_round_robin_sharding (graph , multi_br_dp )
353+ graph , ch1 = insert_round_robin_sharding (graph , ch1 )
354+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), fork_zip_dp )
347355
348356 single_br_dp , * _ , cir_br_dp , _ , end_dp , graph = self ._make_dp ()
349- graph , single_br_dp = make_dp_non_replicable (graph , single_br_dp )
350- graph , cir_br_dp = make_dp_non_replicable (graph , cir_br_dp )
351- self .assertEqual (find_lca_non_replicable_dp (graph ), end_dp )
357+ graph , single_br_dp = insert_round_robin_sharding (graph , single_br_dp )
358+ graph , cir_br_dp = insert_round_robin_sharding (graph , cir_br_dp )
359+ self .assertEqual (find_lca_round_robin_sharding_dp (graph ), end_dp )
352360
353- def test_replicable_branches (self ):
361+ def test_non_dispatching_branches (self ):
354362 r"""
355363 There should be a single DataPipe as the lowest common ancestor of all
356- non-replicable DataPipes that is replaced by ``DummyIterDataPipe``.
364+ non-dispatching DataPipes that is replaced by ``DummyIterDataPipe``.
357365 """
358366 single_br_dp , * _ , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
359367 graph = replace_by_dummy (graph , single_br_dp )
360- dps = find_replicable_branches (graph )
368+ dps = find_non_dispatching_branches (graph )
361369 self .assertEqual (len (dps ), 2 )
362370 self .assertTrue (all (dp in (fork_zip_dp , cir_map_dp ) for dp in dps ))
363371
364372 single_br_dp , multi_br_dp , * _ , cir_map_dp , _ , graph = self ._make_dp ()
365373 graph = replace_by_dummy (graph , multi_br_dp )
366- dps = find_replicable_branches (graph )
374+ dps = find_non_dispatching_branches (graph )
367375 self .assertEqual (len (dps ), 2 )
368376 self .assertTrue (all (dp in (single_br_dp , cir_map_dp ) for dp in dps ))
369377
370378 # In theory, this case should never happen because LCA (fork_zip_dp) should be
371379 # replaced by _DummpyIterDataPipe if any of child is non-replicable
372- single_br_dp , _ , ch1 , ch2 , * _ , cir_map_dp , _ , graph = self ._make_dp ()
380+ single_br_dp , _ , _ , ch1 , ch2 , * _ , cir_map_dp , _ , graph = self ._make_dp ()
373381 graph = replace_by_dummy (graph , ch1 )
374- dps = find_replicable_branches (graph )
382+ dps = find_non_dispatching_branches (graph )
375383 self .assertEqual (len (dps ), 3 )
376384 self .assertTrue (all (dp in (single_br_dp , ch2 , cir_map_dp ) for dp in dps ))
377385
378386 single_br_dp , * _ , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
379387 graph = replace_by_dummy (graph , cir_map_dp )
380- dps = find_replicable_branches (graph )
388+ dps = find_non_dispatching_branches (graph )
381389 self .assertTrue (all (dp in (single_br_dp , fork_zip_dp ) for dp in dps ))
382390
383391 * _ , end_dp , graph = self ._make_dp ()
384392 graph = replace_by_dummy (graph , end_dp )
385- dps = find_replicable_branches (graph )
393+ dps = find_non_dispatching_branches (graph )
386394 self .assertEqual (len (dps ), 0 )
387395
388396 single_br_dp , * _ , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
389397 graph = replace_by_dummy (graph , fork_zip_dp )
390- dps = find_replicable_branches (graph )
398+ dps = find_non_dispatching_branches (graph )
391399 self .assertEqual (len (dps ), 2 )
392400 self .assertTrue (all (dp in (single_br_dp , cir_map_dp ) for dp in dps ))
393401
402+ def test_single_non_replicable_dp (self ):
403+ # All replicable
404+ * _ , end_dp , graph = self ._make_dp ()
405+ dps = _find_replicable_branches (graph )
406+ self .assertEqual (len (dps ), 1 )
407+ self .assertEqual (dps [0 ], end_dp )
408+
409+ # Test the production use case where the last DataPipe is fullsync
410+ * _ , end_dp , _ = self ._make_dp ()
411+ dp = end_dp .fullsync ()
412+ graph = traverse_dps (dp )
413+ dps = _find_replicable_branches (graph )
414+ self .assertEqual (len (dps ), 1 )
415+ self .assertEqual (dps [0 ], end_dp )
416+
417+ single_br_dp , * _ , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
418+ make_non_replicable_dp (single_br_dp )
419+ dps = _find_replicable_branches (graph )
420+ self .assertEqual (len (dps ), 2 )
421+ self .assertTrue (all (dp in (fork_zip_dp , cir_map_dp ) for dp in dps ))
422+
423+ single_br_dp , * _ , ch1 , ch2 , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
424+ make_non_replicable_dp (fork_zip_dp )
425+ dps = _find_replicable_branches (graph )
426+ self .assertEqual (len (dps ), 4 )
427+ self .assertTrue (all (dp in (single_br_dp , ch1 , ch2 , cir_map_dp ) for dp in dps ))
428+
429+ single_br_dp , _ , forker_dp , ch1 , * _ , cir_map_dp , _ , graph = self ._make_dp ()
430+ make_non_replicable_dp (ch1 )
431+ dps = _find_replicable_branches (graph )
432+ self .assertEqual (len (dps ), 3 )
433+ self .assertTrue (all (dp in (single_br_dp , forker_dp , cir_map_dp ) for dp in dps ))
434+
435+ single_br_dp , * _ , fork_zip_dp , cir_br_dp , cir_map_dp , _ , graph = self ._make_dp ()
436+ make_non_replicable_dp (cir_map_dp )
437+ dps = _find_replicable_branches (graph )
438+ self .assertEqual (len (dps ), 3 )
439+ self .assertTrue (all (dp in (single_br_dp , fork_zip_dp , cir_br_dp ) for dp in dps ))
440+
441+ single_br_dp , * _ , fork_zip_dp , _ , cir_map_dp , end_dp , graph = self ._make_dp ()
442+ make_non_replicable_dp (end_dp )
443+ dps = _find_replicable_branches (graph )
444+ self .assertEqual (len (dps ), 3 )
445+ self .assertTrue (all (dp in (single_br_dp , fork_zip_dp , cir_map_dp ) for dp in dps ))
446+
447+ def test_multi_non_replicable_dps (self ):
448+ single_br_dp , multi_br_dp , * _ , cir_map_dp , _ , graph = self ._make_dp ()
449+ make_non_replicable_dp (single_br_dp )
450+ make_non_replicable_dp (multi_br_dp )
451+ dps = _find_replicable_branches (graph )
452+ self .assertEqual (len (dps ), 1 )
453+ self .assertEqual (dps [0 ], cir_map_dp )
454+
455+ single_br_dp , _ , forker_dp , ch1 , * _ , cir_map_dp , _ , graph = self ._make_dp ()
456+ make_non_replicable_dp (single_br_dp )
457+ make_non_replicable_dp (ch1 )
458+ dps = _find_replicable_branches (graph )
459+ self .assertEqual (len (dps ), 2 )
460+ self .assertTrue (all (dp in (forker_dp , cir_map_dp ) for dp in dps ))
461+
462+ single_br_dp , * _ , ch1 , ch2 , fork_zip_dp , _ , cir_map_dp , _ , graph = self ._make_dp ()
463+ make_non_replicable_dp (single_br_dp )
464+ make_non_replicable_dp (fork_zip_dp )
465+ dps = _find_replicable_branches (graph )
466+ self .assertEqual (len (dps ), 3 )
467+ self .assertTrue (all (dp in (ch1 , ch2 , cir_map_dp ) for dp in dps ))
468+
469+ single_br_dp , * _ , fork_zip_dp , cir_br_dp , cir_map_dp , _ , graph = self ._make_dp ()
470+ make_non_replicable_dp (single_br_dp )
471+ make_non_replicable_dp (cir_map_dp )
472+ dps = _find_replicable_branches (graph )
473+ self .assertEqual (len (dps ), 2 )
474+ self .assertTrue (all (dp in (fork_zip_dp , cir_br_dp ) for dp in dps ))
475+
476+ single_br_dp , multi_br_dp , forker_dp , ch1 , * _ , cir_map_dp , _ , graph = self ._make_dp ()
477+ make_non_replicable_dp (forker_dp )
478+ make_non_replicable_dp (ch1 )
479+ dps = _find_replicable_branches (graph )
480+ self .assertEqual (len (dps ), 3 )
481+ self .assertTrue (all (dp in (single_br_dp , multi_br_dp , cir_map_dp ) for dp in dps ))
482+
483+ single_br_dp , multi_br_dp , forker_dp , * _ , cir_br_dp , cir_map_dp , _ , graph = self ._make_dp ()
484+ make_non_replicable_dp (forker_dp )
485+ make_non_replicable_dp (cir_map_dp )
486+ dps = _find_replicable_branches (graph )
487+ self .assertEqual (len (dps ), 3 )
488+ self .assertTrue (all (dp in (single_br_dp , multi_br_dp , cir_br_dp ) for dp in dps ))
489+
490+ single_br_dp , * _ , ch1 , ch2 , fork_zip_dp , cir_br_dp , cir_map_dp , _ , graph = self ._make_dp ()
491+ make_non_replicable_dp (fork_zip_dp )
492+ make_non_replicable_dp (cir_map_dp )
493+ dps = _find_replicable_branches (graph )
494+ self .assertEqual (len (dps ), 4 )
495+ self .assertTrue (all (dp in (single_br_dp , ch1 , ch2 , cir_br_dp ) for dp in dps ))
496+
394497
395498class TestGraphVisualization (expecttest .TestCase ):
396499 @unittest .skipIf (not HAS_GRAPHVIZ , "Package `graphviz` is required to test graph visualization functionalities." )
0 commit comments