@@ -894,6 +894,142 @@ def mul_fn(a, b):
894894 with self .assertRaisesRegex (TypeError , "length relies on the output of its function." ):
895895 len (flatmapped_dp )
896896
def test_shuffled_flatmap_iterdatapipe(self):
    # Covers `shuffled_flatmap` in three regimes:
    #   1. buffer_size=1          -> output order is expected to match plain flatmap
    #   2. .set_shuffle(False)    -> output order is expected to match plain flatmap
    #   3. default shuffling      -> only order-independent / seed-reproducibility checks
    source_dp = IterableWrapper(list(range(20)))

    def fn(e):
        return [e, e * 10]

    # Flattened, in-order result of applying `fn` to every element of `source_dp`.
    # It is never mutated below, so it is shared by both deterministic sections.
    expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp]))

    # Tests with buffer_size=1
    # In this case, the expected behavior is similar to flatmap

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn, buffer_size=1)
    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Functional Test: Specify input_col
    tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)])

    # Single input_col
    input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1, buffer_size=1)
    self.assertEqual(expected_list, list(input_col_1_dp))

    # With generator as fn
    def gen_fn(e):
        yield e
        yield e * 10

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(gen_fn, buffer_size=1)
    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Multiple input_col
    def mul_fn(a, b):
        return [a - b, b - a]

    input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2), buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp))

    # shuffled_flatmap with no fn specified
    default_dp = tuple_source_dp.shuffled_flatmap(buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, multiple input_col
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1)
    self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, some special input
    # (nested lists must be yielded as-is, i.e. only one level is flattened)
    tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]])
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1)
    self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp))

    # Reset Test: reset the DataPipe after reading part of it
    # (`shuffled_flatmapped_dp` is still the `gen_fn`, buffer_size=1 pipe from above)
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)

    self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_list, res_after_reset)

    # __len__ Test: length should be len(source_dp)*len(fn->out_shape) which we can't know
    with self.assertRaisesRegex(TypeError, "length relies on the output of its function."):
        len(shuffled_flatmapped_dp)

    # __len__ when no fn specified: 2 + 0 + 1 + 4 top-level items
    dp = IterableWrapper([[1, 2], [], [3], [4, 5, 6, [7, 8]]])
    dp = dp.shuffled_flatmap()
    self.assertEqual(len(dp), 7)

    # Tests with .set_shuffle(False)
    # In this case, the expected behavior is similar to flatmap

    shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn).set_shuffle(False)
    self.assertEqual(expected_list, list(shuffled_flatmapped_dp))

    # Functional Test: Specify input_col
    tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)])

    # Single input_col
    # Disable shuffling here (rather than passing buffer_size=1) so this case
    # actually exercises the set_shuffle(False) path like the rest of the section.
    input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1).set_shuffle(False)
    self.assertEqual(expected_list, list(input_col_1_dp))

    # Multiple input_col
    input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2)).set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp))

    # shuffled_flatmap with no fn specified
    default_dp = tuple_source_dp.shuffled_flatmap().set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, multiple input_col
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False)
    self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp))

    # shuffled_flatmap with no fn specified, some special input
    tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]])
    default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False)
    self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp))

    # Reset Test: reset the DataPipe after reading part of it
    # (`shuffled_flatmapped_dp` is the `fn` pipe with shuffling disabled)
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)

    self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset)
    self.assertEqual(expected_list, res_after_reset)

    # Other tests

    # Test no empty buffers:
    with self.assertRaises(AssertionError):
        _ = source_dp.shuffled_flatmap(buffer_size=0)

    # Functional Test: No seed
    # Order is unspecified when shuffling, so compare as sets.
    consecutive_tuple_source_dp = IterableWrapper([(d, d + 1, d + 2) for d in range(0, 21, 3)])
    shuffled_flatmapped_dp = consecutive_tuple_source_dp.shuffled_flatmap()
    self.assertEqual(set(range(21)), set(shuffled_flatmapped_dp))

    # Functional Test: With global seed
    # Re-seeding the global RNG must reproduce the same order.
    torch.manual_seed(123)
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap()
    res = list(shuffled_flatmapped_dp)
    torch.manual_seed(123)
    self.assertEqual(list(shuffled_flatmapped_dp), res)

    # Functional Test: Set seed
    # Re-applying the same explicit seed must reproduce the same order.
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap().set_seed(123)
    res = list(shuffled_flatmapped_dp)
    shuffled_flatmapped_dp.set_seed(123)
    self.assertEqual(list(shuffled_flatmapped_dp), res)

    # Reset Test:
    # With shuffling enabled only the number of items read before reset is checked.
    shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap()
    n_elements_before_reset = 5
    res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset)
    self.assertEqual(5, len(res_before_reset))
1032+
8971033 def test_round_robin_demux_iterdatapipe (self ):
8981034 source_dp = IterableWrapper (list (range (23 )))
8991035 with self .assertRaisesRegex (ValueError , "Expected `num_instaces`" ):
0 commit comments