Skip to content

Commit fea20d4

Browse files
SvenDS9 authored and facebook-github-bot committed
Add ThreadPoolMapper (meta-pytorch#1052)
Summary: Fixes meta-pytorch#1045 ### Changes - Add ThreadPoolMapper datapipe - Add tests Pull Request resolved: meta-pytorch#1052 Reviewed By: NivekT Differential Revision: D44033894 Pulled By: ejguan fbshipit-source-id: 608b70a857e4610fc6616a53711c706207ce696a
1 parent f1283eb commit fea20d4

5 files changed

Lines changed: 474 additions & 3 deletions

File tree

docs/source/torchdata.datapipes.iter.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ These DataPipes apply the a given function to each element in the DataPipe.
167167
BatchMapper
168168
FlatMapper
169169
Mapper
170+
ThreadPoolMapper
170171

171172
Other DataPipes
172173
-------------------------

test/test_iterdatapipe.py

Lines changed: 272 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import warnings
1313

1414
from collections import defaultdict
15+
from functools import partial
1516
from typing import Dict
1617

1718
import expecttest
@@ -20,6 +21,7 @@
2021
import torchdata
2122

2223
from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls
24+
from torch.testing._internal.common_utils import suppress_warnings
2325

2426
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
2527
from torchdata.datapipes.iter import (
@@ -80,12 +82,12 @@ def _convert_to_tensor(data):
8082

8183

8284
async def _async_mul_ten(x):
83-
await asyncio.sleep(1)
85+
await asyncio.sleep(0.1)
8486
return x * 10
8587

8688

8789
async def _async_x_mul_y(x, y):
88-
await asyncio.sleep(1)
90+
await asyncio.sleep(0.1)
8991
return x * y
9092

9193

@@ -289,7 +291,7 @@ def odd_even_bug(i: int) -> int:
289291
self.assertEqual(len(source_dp), len(result_dp))
290292

291293
def test_prefetcher_iterdatapipe(self) -> None:
292-
source_dp = IterableWrapper(range(50000))
294+
source_dp = IterableWrapper(range(5000))
293295
prefetched_dp = source_dp.prefetch(10)
294296
# check if early termination resets child thread properly
295297
for _, _ in zip(range(100), prefetched_dp):
@@ -1619,6 +1621,273 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_
16191621
self.assertEqual(v1, exp)
16201622
self.assertEqual(v2, exp)
16211623

1624+
def test_threadpool_map(self):
1625+
target_length = 30
1626+
input_dp = IterableWrapper(range(target_length))
1627+
input_dp_parallel = IterableWrapper(range(target_length))
1628+
1629+
def fn(item, dtype=torch.float, *, sum=False):
1630+
data = torch.tensor(item, dtype=dtype)
1631+
return data if not sum else data.sum()
1632+
1633+
# Functional Test: apply to each element correctly
1634+
map_dp = input_dp.threadpool_map(fn)
1635+
self.assertEqual(target_length, len(map_dp))
1636+
for x, y in zip(map_dp, range(target_length)):
1637+
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1638+
1639+
# Functional Test: works with partial function
1640+
map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True))
1641+
for x, y in zip(map_dp, range(target_length)):
1642+
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
1643+
1644+
# __len__ Test: inherits length from source DataPipe
1645+
self.assertEqual(target_length, len(map_dp))
1646+
1647+
input_dp_nl = IDP_NoLen(range(target_length))
1648+
map_dp_nl = input_dp_nl.threadpool_map(lambda x: x)
1649+
for x, y in zip(map_dp_nl, range(target_length)):
1650+
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
1651+
1652+
# __len__ Test: inherits length from source DataPipe - raises error when invalid
1653+
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
1654+
len(map_dp_nl)
1655+
1656+
# Test: two independent ThreadPoolExecutors running at the same time
1657+
map_dp_parallel = input_dp_parallel.threadpool_map(fn)
1658+
for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)):
1659+
self.assertEqual(x, torch.tensor(z, dtype=torch.float))
1660+
self.assertEqual(y, torch.tensor(z, dtype=torch.float))
1661+
1662+
# Reset Test: DataPipe resets properly
1663+
n_elements_before_reset = 5
1664+
res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
1665+
self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
1666+
self.assertEqual(list(range(target_length)), res_after_reset)
1667+
1668+
@suppress_warnings # Suppress warning for lambda fn
1669+
def test_threadpool_map_tuple_list_with_col_iterdatapipe(self):
1670+
def fn_11(d):
1671+
return -d
1672+
1673+
def fn_1n(d):
1674+
return -d, d
1675+
1676+
def fn_n1(d0, d1):
1677+
return d0 + d1
1678+
1679+
def fn_nn(d0, d1):
1680+
return -d0, -d1, d0 + d1
1681+
1682+
def fn_n1_def(d0, d1=1):
1683+
return d0 + d1
1684+
1685+
def fn_n1_kwargs(d0, d1, **kwargs):
1686+
return d0 + d1
1687+
1688+
def fn_n1_pos(d0, d1, *args):
1689+
return d0 + d1
1690+
1691+
def fn_n1_sep_pos(d0, *args, d1):
1692+
return d0 + d1
1693+
1694+
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1695+
return d0 + d1
1696+
1697+
p_fn_n1 = partial(fn_n1, d1=1)
1698+
p_fn_cmplx = partial(fn_cmplx, d2=2)
1699+
1700+
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1701+
for constr in (list, tuple):
1702+
datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
1703+
if ref_fn is None:
1704+
with self.assertRaises(error):
1705+
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
1706+
list(res_dp)
1707+
else:
1708+
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
1709+
ref_dp = datapipe.map(ref_fn)
1710+
if constr is list:
1711+
ref_dp = ref_dp.map(list)
1712+
self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
1713+
# Reset
1714+
self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")
1715+
1716+
_helper(lambda data: data, fn_n1_def, 0, 1)
1717+
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
1718+
_helper(lambda data: data, p_fn_n1, 0, 1)
1719+
_helper(lambda data: data, p_fn_cmplx, 0, 1)
1720+
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
1721+
_helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])
1722+
1723+
# Replacing with one input column and default output column
1724+
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
1725+
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
1726+
# The index of input column is out of range
1727+
_helper(None, fn_1n, 3, error=IndexError)
1728+
# Unmatched input columns with fn arguments
1729+
_helper(None, fn_n1, 1, error=ValueError)
1730+
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
1731+
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
1732+
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
1733+
_helper(None, fn_cmplx, 0, 1, ValueError)
1734+
_helper(None, fn_n1_pos, 1, error=ValueError)
1735+
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
1736+
_helper(None, p_fn_n1, [0, 1], error=ValueError)
1737+
_helper(None, fn_1n, [1, 2], error=ValueError)
1738+
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
1739+
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
1740+
# Fn has keyword-only arguments
1741+
_helper(None, fn_n1_kwargs, 1, error=ValueError)
1742+
_helper(None, fn_cmplx, [0, 1], 2, ValueError)
1743+
1744+
# Replacing with multiple input columns and default output column (the left-most input column)
1745+
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
1746+
_helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1])
1747+
1748+
# output_col can only be specified when input_col is not None
1749+
_helper(None, fn_n1, None, 1, error=ValueError)
1750+
# output_col can only be single-element list or tuple
1751+
_helper(None, fn_n1, None, [0, 1], error=ValueError)
1752+
# Single-element list as output_col
1753+
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
1754+
# Replacing with one input column and single specified output column
1755+
_helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
1756+
_helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
1757+
# The index of output column is out of range
1758+
_helper(None, fn_1n, 1, 3, error=IndexError)
1759+
_helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
1760+
_helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0)
1761+
1762+
# Appending the output at the end
1763+
_helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
1764+
_helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
1765+
_helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
1766+
_helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1)
1767+
1768+
# Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
1769+
_helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
1770+
_helper(lambda data: (data[0], data[1], int(data[2])), int, 2)
1771+
1772+
@suppress_warnings # Suppress warning for lambda fn
1773+
def test_threadpool_map_dict_with_col_iterdatapipe(self):
1774+
def fn_11(d):
1775+
return -d
1776+
1777+
def fn_1n(d):
1778+
return -d, d
1779+
1780+
def fn_n1(d0, d1):
1781+
return d0 + d1
1782+
1783+
def fn_nn(d0, d1):
1784+
return -d0, -d1, d0 + d1
1785+
1786+
def fn_n1_def(d0, d1=1):
1787+
return d0 + d1
1788+
1789+
p_fn_n1 = partial(fn_n1, d1=1)
1790+
1791+
def fn_n1_pos(d0, d1, *args):
1792+
return d0 + d1
1793+
1794+
def fn_n1_kwargs(d0, d1, **kwargs):
1795+
return d0 + d1
1796+
1797+
def fn_kwonly(*, d0, d1):
1798+
return d0 + d1
1799+
1800+
def fn_has_nondefault_kwonly(d0, *, d1):
1801+
return d0 + d1
1802+
1803+
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
1804+
return d0 + d1
1805+
1806+
p_fn_cmplx = partial(fn_cmplx, d2=2)
1807+
1808+
# Prevent modification in-place to support resetting
1809+
def _dict_update(data, newdata, remove_idx=None):
1810+
_data = dict(data)
1811+
_data.update(newdata)
1812+
if remove_idx:
1813+
for idx in remove_idx:
1814+
del _data[idx]
1815+
return _data
1816+
1817+
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
1818+
datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}])
1819+
if ref_fn is None:
1820+
with self.assertRaises(error):
1821+
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
1822+
list(res_dp)
1823+
else:
1824+
res_dp = datapipe.threadpool_map(fn, input_col, output_col)
1825+
ref_dp = datapipe.map(ref_fn)
1826+
self.assertEqual(list(res_dp), list(ref_dp), "First test failed")
1827+
# Reset
1828+
self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed")
1829+
1830+
_helper(lambda data: data, fn_n1_def, "x", "y")
1831+
_helper(lambda data: data, p_fn_n1, "x", "y")
1832+
_helper(lambda data: data, p_fn_cmplx, "x", "y")
1833+
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z")
1834+
1835+
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z")
1836+
1837+
_helper(None, fn_n1_pos, "x", error=ValueError)
1838+
_helper(None, fn_n1_kwargs, "x", error=ValueError)
1839+
# non-default kw-only args
1840+
_helper(None, fn_kwonly, ["x", "y"], error=ValueError)
1841+
_helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
1842+
_helper(None, fn_cmplx, ["x", "y"], error=ValueError)
1843+
1844+
# Replacing with one input column and default output column
1845+
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
1846+
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
1847+
# The key of input column is not in dict
1848+
_helper(None, fn_1n, "a", error=KeyError)
1849+
# Unmatched input columns with fn arguments
1850+
_helper(None, fn_n1, "y", error=ValueError)
1851+
_helper(None, fn_1n, ["x", "y"], error=ValueError)
1852+
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
1853+
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
1854+
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
1855+
# Replacing with multiple input columns and default output column (the left-most input column)
1856+
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
1857+
_helper(
1858+
lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]),
1859+
fn_nn,
1860+
["z", "y"],
1861+
)
1862+
1863+
# output_col can only be specified when input_col is not None
1864+
_helper(None, fn_n1, None, "x", error=ValueError)
1865+
# output_col can only be single-element list or tuple
1866+
_helper(None, fn_n1, None, ["x", "y"], error=ValueError)
1867+
# Single-element list as output_col
1868+
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
1869+
# Replacing with one input column and single specified output column
1870+
_helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
1871+
_helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z")
1872+
_helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y")
1873+
_helper(
1874+
lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}),
1875+
fn_nn,
1876+
["y", "z"],
1877+
"x",
1878+
)
1879+
1880+
# Adding new key to dict for the output
1881+
_helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
1882+
_helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a")
1883+
_helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a")
1884+
_helper(
1885+
lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}),
1886+
fn_nn,
1887+
["y", "z"],
1888+
"a",
1889+
)
1890+
16221891

16231892
if __name__ == "__main__":
16241893
unittest.main()

test/test_serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def test_serializable(self):
297297
(iterdp.TarArchiveLoader, None, (), {}),
298298
# TODO(594): Add serialization tests for optional DataPipe
299299
# (iterdp.TFRecordLoader, None, (), {}),
300+
(iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}),
300301
(iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}),
301302
(iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}),
302303
(iterdp.XzFileLoader, None, (), {}),
@@ -366,6 +367,7 @@ def test_serializable_with_dill(self):
366367
(iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}),
367368
(iterdp.OnDiskCacheHolder, (lambda x: x,), {}),
368369
(iterdp.ParagraphAggregator, (lambda x: x,), {}),
370+
(iterdp.ThreadPoolMapper, (lambda x: x,), {}),
369371
]
370372
# Skipping value comparison for these DataPipes
371373
dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator}

torchdata/datapipes/iter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
FlatMapperIterDataPipe as FlatMapper,
7373
FlattenIterDataPipe as Flattener,
7474
SliceIterDataPipe as Slicer,
75+
ThreadPoolMapperIterDataPipe as ThreadPoolMapper,
7576
)
7677
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
7778
from torchdata.datapipes.iter.util.cacheholder import (
@@ -213,6 +214,7 @@
213214
"StreamReader",
214215
"TFRecordLoader",
215216
"TarArchiveLoader",
217+
"ThreadPoolMapper",
216218
"UnBatcher",
217219
"UnZipper",
218220
"WebDataset",

0 commit comments

Comments
 (0)