|
12 | 12 | import warnings |
13 | 13 |
|
14 | 14 | from collections import defaultdict |
| 15 | +from functools import partial |
15 | 16 | from typing import Dict |
16 | 17 |
|
17 | 18 | import expecttest |
|
20 | 21 | import torchdata |
21 | 22 |
|
22 | 23 | from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls |
| 24 | +from torch.testing._internal.common_utils import suppress_warnings |
23 | 25 |
|
24 | 26 | from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration |
25 | 27 | from torchdata.datapipes.iter import ( |
@@ -80,12 +82,12 @@ def _convert_to_tensor(data): |
80 | 82 |
|
81 | 83 |
|
82 | 84 | async def _async_mul_ten(x): |
83 | | - await asyncio.sleep(1) |
| 85 | + await asyncio.sleep(0.1) |
84 | 86 | return x * 10 |
85 | 87 |
|
86 | 88 |
|
87 | 89 | async def _async_x_mul_y(x, y): |
88 | | - await asyncio.sleep(1) |
| 90 | + await asyncio.sleep(0.1) |
89 | 91 | return x * y |
90 | 92 |
|
91 | 93 |
|
@@ -289,7 +291,7 @@ def odd_even_bug(i: int) -> int: |
289 | 291 | self.assertEqual(len(source_dp), len(result_dp)) |
290 | 292 |
|
291 | 293 | def test_prefetcher_iterdatapipe(self) -> None: |
292 | | - source_dp = IterableWrapper(range(50000)) |
| 294 | + source_dp = IterableWrapper(range(5000)) |
293 | 295 | prefetched_dp = source_dp.prefetch(10) |
294 | 296 | # check if early termination resets child thread properly |
295 | 297 | for _, _ in zip(range(100), prefetched_dp): |
@@ -1619,6 +1621,273 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_ |
1619 | 1621 | self.assertEqual(v1, exp) |
1620 | 1622 | self.assertEqual(v2, exp) |
1621 | 1623 |
|
| 1624 | + def test_threadpool_map(self): |
| 1625 | + target_length = 30 |
| 1626 | + input_dp = IterableWrapper(range(target_length)) |
| 1627 | + input_dp_parallel = IterableWrapper(range(target_length)) |
| 1628 | + |
| 1629 | + def fn(item, dtype=torch.float, *, sum=False): |
| 1630 | + data = torch.tensor(item, dtype=dtype) |
| 1631 | + return data if not sum else data.sum() |
| 1632 | + |
| 1633 | + # Functional Test: apply to each element correctly |
| 1634 | + map_dp = input_dp.threadpool_map(fn) |
| 1635 | + self.assertEqual(target_length, len(map_dp)) |
| 1636 | + for x, y in zip(map_dp, range(target_length)): |
| 1637 | + self.assertEqual(x, torch.tensor(y, dtype=torch.float)) |
| 1638 | + |
| 1639 | + # Functional Test: works with partial function |
| 1640 | + map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True)) |
| 1641 | + for x, y in zip(map_dp, range(target_length)): |
| 1642 | + self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) |
| 1643 | + |
| 1644 | + # __len__ Test: inherits length from source DataPipe |
| 1645 | + self.assertEqual(target_length, len(map_dp)) |
| 1646 | + |
| 1647 | + input_dp_nl = IDP_NoLen(range(target_length)) |
| 1648 | + map_dp_nl = input_dp_nl.threadpool_map(lambda x: x) |
| 1649 | + for x, y in zip(map_dp_nl, range(target_length)): |
| 1650 | + self.assertEqual(x, torch.tensor(y, dtype=torch.float)) |
| 1651 | + |
| 1652 | + # __len__ Test: inherits length from source DataPipe - raises error when invalid |
| 1653 | + with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| 1654 | + len(map_dp_nl) |
| 1655 | + |
| 1656 | + # Test: two independent ThreadPoolExecutors running at the same time |
| 1657 | + map_dp_parallel = input_dp_parallel.threadpool_map(fn) |
| 1658 | + for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)): |
| 1659 | + self.assertEqual(x, torch.tensor(z, dtype=torch.float)) |
| 1660 | + self.assertEqual(y, torch.tensor(z, dtype=torch.float)) |
| 1661 | + |
| 1662 | + # Reset Test: DataPipe resets properly |
| 1663 | + n_elements_before_reset = 5 |
| 1664 | + res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset) |
| 1665 | + self.assertEqual(list(range(n_elements_before_reset)), res_before_reset) |
| 1666 | + self.assertEqual(list(range(target_length)), res_after_reset) |
| 1667 | + |
| 1668 | + @suppress_warnings # Suppress warning for lambda fn |
| 1669 | + def test_threadpool_map_tuple_list_with_col_iterdatapipe(self): |
| 1670 | + def fn_11(d): |
| 1671 | + return -d |
| 1672 | + |
| 1673 | + def fn_1n(d): |
| 1674 | + return -d, d |
| 1675 | + |
| 1676 | + def fn_n1(d0, d1): |
| 1677 | + return d0 + d1 |
| 1678 | + |
| 1679 | + def fn_nn(d0, d1): |
| 1680 | + return -d0, -d1, d0 + d1 |
| 1681 | + |
| 1682 | + def fn_n1_def(d0, d1=1): |
| 1683 | + return d0 + d1 |
| 1684 | + |
| 1685 | + def fn_n1_kwargs(d0, d1, **kwargs): |
| 1686 | + return d0 + d1 |
| 1687 | + |
| 1688 | + def fn_n1_pos(d0, d1, *args): |
| 1689 | + return d0 + d1 |
| 1690 | + |
| 1691 | + def fn_n1_sep_pos(d0, *args, d1): |
| 1692 | + return d0 + d1 |
| 1693 | + |
| 1694 | + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): |
| 1695 | + return d0 + d1 |
| 1696 | + |
| 1697 | + p_fn_n1 = partial(fn_n1, d1=1) |
| 1698 | + p_fn_cmplx = partial(fn_cmplx, d2=2) |
| 1699 | + |
| 1700 | + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): |
| 1701 | + for constr in (list, tuple): |
| 1702 | + datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) |
| 1703 | + if ref_fn is None: |
| 1704 | + with self.assertRaises(error): |
| 1705 | + res_dp = datapipe.threadpool_map(fn, input_col, output_col) |
| 1706 | + list(res_dp) |
| 1707 | + else: |
| 1708 | + res_dp = datapipe.threadpool_map(fn, input_col, output_col) |
| 1709 | + ref_dp = datapipe.map(ref_fn) |
| 1710 | + if constr is list: |
| 1711 | + ref_dp = ref_dp.map(list) |
| 1712 | + self.assertEqual(list(res_dp), list(ref_dp), "First test failed") |
| 1713 | + # Reset |
| 1714 | + self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") |
| 1715 | + |
| 1716 | + _helper(lambda data: data, fn_n1_def, 0, 1) |
| 1717 | + _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2) |
| 1718 | + _helper(lambda data: data, p_fn_n1, 0, 1) |
| 1719 | + _helper(lambda data: data, p_fn_cmplx, 0, 1) |
| 1720 | + _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2) |
| 1721 | + _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2]) |
| 1722 | + |
| 1723 | + # Replacing with one input column and default output column |
| 1724 | + _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) |
| 1725 | + _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1) |
| 1726 | + # The index of input column is out of range |
| 1727 | + _helper(None, fn_1n, 3, error=IndexError) |
| 1728 | + # Unmatched input columns with fn arguments |
| 1729 | + _helper(None, fn_n1, 1, error=ValueError) |
| 1730 | + _helper(None, fn_n1, [0, 1, 2], error=ValueError) |
| 1731 | + _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError) |
| 1732 | + _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError) |
| 1733 | + _helper(None, fn_cmplx, 0, 1, ValueError) |
| 1734 | + _helper(None, fn_n1_pos, 1, error=ValueError) |
| 1735 | + _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError) |
| 1736 | + _helper(None, p_fn_n1, [0, 1], error=ValueError) |
| 1737 | + _helper(None, fn_1n, [1, 2], error=ValueError) |
| 1738 | + # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError) |
| 1739 | + _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError) |
| 1740 | + # Fn has keyword-only arguments |
| 1741 | + _helper(None, fn_n1_kwargs, 1, error=ValueError) |
| 1742 | + _helper(None, fn_cmplx, [0, 1], 2, ValueError) |
| 1743 | + |
| 1744 | + # Replacing with multiple input columns and default output column (the left-most input column) |
| 1745 | + _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) |
| 1746 | + _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1]) |
| 1747 | + |
| 1748 | + # output_col can only be specified when input_col is not None |
| 1749 | + _helper(None, fn_n1, None, 1, error=ValueError) |
| 1750 | + # output_col can only be single-element list or tuple |
| 1751 | + _helper(None, fn_n1, None, [0, 1], error=ValueError) |
| 1752 | + # Single-element list as output_col |
| 1753 | + _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) |
| 1754 | + # Replacing with one input column and single specified output column |
| 1755 | + _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) |
| 1756 | + _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) |
| 1757 | + # The index of output column is out of range |
| 1758 | + _helper(None, fn_1n, 1, 3, error=IndexError) |
| 1759 | + _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) |
| 1760 | + _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0) |
| 1761 | + |
| 1762 | + # Appending the output at the end |
| 1763 | + _helper(lambda data: (*data, -data[1]), fn_11, 1, -1) |
| 1764 | + _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1) |
| 1765 | + _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1) |
| 1766 | + _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1) |
| 1767 | + |
| 1768 | + # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected |
| 1769 | + _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0) |
| 1770 | + _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) |
| 1771 | + |
| 1772 | + @suppress_warnings # Suppress warning for lambda fn |
| 1773 | + def test_threadpool_map_dict_with_col_iterdatapipe(self): |
| 1774 | + def fn_11(d): |
| 1775 | + return -d |
| 1776 | + |
| 1777 | + def fn_1n(d): |
| 1778 | + return -d, d |
| 1779 | + |
| 1780 | + def fn_n1(d0, d1): |
| 1781 | + return d0 + d1 |
| 1782 | + |
| 1783 | + def fn_nn(d0, d1): |
| 1784 | + return -d0, -d1, d0 + d1 |
| 1785 | + |
| 1786 | + def fn_n1_def(d0, d1=1): |
| 1787 | + return d0 + d1 |
| 1788 | + |
| 1789 | + p_fn_n1 = partial(fn_n1, d1=1) |
| 1790 | + |
| 1791 | + def fn_n1_pos(d0, d1, *args): |
| 1792 | + return d0 + d1 |
| 1793 | + |
| 1794 | + def fn_n1_kwargs(d0, d1, **kwargs): |
| 1795 | + return d0 + d1 |
| 1796 | + |
| 1797 | + def fn_kwonly(*, d0, d1): |
| 1798 | + return d0 + d1 |
| 1799 | + |
| 1800 | + def fn_has_nondefault_kwonly(d0, *, d1): |
| 1801 | + return d0 + d1 |
| 1802 | + |
| 1803 | + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): |
| 1804 | + return d0 + d1 |
| 1805 | + |
| 1806 | + p_fn_cmplx = partial(fn_cmplx, d2=2) |
| 1807 | + |
| 1808 | + # Prevent modification in-place to support resetting |
| 1809 | + def _dict_update(data, newdata, remove_idx=None): |
| 1810 | + _data = dict(data) |
| 1811 | + _data.update(newdata) |
| 1812 | + if remove_idx: |
| 1813 | + for idx in remove_idx: |
| 1814 | + del _data[idx] |
| 1815 | + return _data |
| 1816 | + |
| 1817 | + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): |
| 1818 | + datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}]) |
| 1819 | + if ref_fn is None: |
| 1820 | + with self.assertRaises(error): |
| 1821 | + res_dp = datapipe.threadpool_map(fn, input_col, output_col) |
| 1822 | + list(res_dp) |
| 1823 | + else: |
| 1824 | + res_dp = datapipe.threadpool_map(fn, input_col, output_col) |
| 1825 | + ref_dp = datapipe.map(ref_fn) |
| 1826 | + self.assertEqual(list(res_dp), list(ref_dp), "First test failed") |
| 1827 | + # Reset |
| 1828 | + self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") |
| 1829 | + |
| 1830 | + _helper(lambda data: data, fn_n1_def, "x", "y") |
| 1831 | + _helper(lambda data: data, p_fn_n1, "x", "y") |
| 1832 | + _helper(lambda data: data, p_fn_cmplx, "x", "y") |
| 1833 | + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z") |
| 1834 | + |
| 1835 | + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z") |
| 1836 | + |
| 1837 | + _helper(None, fn_n1_pos, "x", error=ValueError) |
| 1838 | + _helper(None, fn_n1_kwargs, "x", error=ValueError) |
| 1839 | + # non-default kw-only args |
| 1840 | + _helper(None, fn_kwonly, ["x", "y"], error=ValueError) |
| 1841 | + _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError) |
| 1842 | + _helper(None, fn_cmplx, ["x", "y"], error=ValueError) |
| 1843 | + |
| 1844 | + # Replacing with one input column and default output column |
| 1845 | + _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") |
| 1846 | + _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") |
| 1847 | + # The key of input column is not in dict |
| 1848 | + _helper(None, fn_1n, "a", error=KeyError) |
| 1849 | + # Unmatched input columns with fn arguments |
| 1850 | + _helper(None, fn_n1, "y", error=ValueError) |
| 1851 | + _helper(None, fn_1n, ["x", "y"], error=ValueError) |
| 1852 | + _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError) |
| 1853 | + _helper(None, p_fn_n1, ["x", "y"], error=ValueError) |
| 1854 | + _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError) |
| 1855 | + # Replacing with multiple input columns and default output column (the left-most input column) |
| 1856 | + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) |
| 1857 | + _helper( |
| 1858 | + lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), |
| 1859 | + fn_nn, |
| 1860 | + ["z", "y"], |
| 1861 | + ) |
| 1862 | + |
| 1863 | + # output_col can only be specified when input_col is not None |
| 1864 | + _helper(None, fn_n1, None, "x", error=ValueError) |
| 1865 | + # output_col can only be single-element list or tuple |
| 1866 | + _helper(None, fn_n1, None, ["x", "y"], error=ValueError) |
| 1867 | + # Single-element list as output_col |
| 1868 | + _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) |
| 1869 | + # Replacing with one input column and single specified output column |
| 1870 | + _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x") |
| 1871 | + _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z") |
| 1872 | + _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y") |
| 1873 | + _helper( |
| 1874 | + lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), |
| 1875 | + fn_nn, |
| 1876 | + ["y", "z"], |
| 1877 | + "x", |
| 1878 | + ) |
| 1879 | + |
| 1880 | + # Adding new key to dict for the output |
| 1881 | + _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a") |
| 1882 | + _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a") |
| 1883 | + _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a") |
| 1884 | + _helper( |
| 1885 | + lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), |
| 1886 | + fn_nn, |
| 1887 | + ["y", "z"], |
| 1888 | + "a", |
| 1889 | + ) |
| 1890 | + |
1622 | 1891 |
|
1623 | 1892 | if __name__ == "__main__": |
1624 | 1893 | unittest.main() |
0 commit comments