Skip to content

Commit 2ee03ae

Browse files
authored
add HF test again (meta-pytorch#1276)
1 parent 0aefb67 commit 2ee03ae

3 files changed

Lines changed: 111 additions & 1 deletion

File tree

.github/workflows/stateful_dataloader_ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,6 @@ jobs:
9090
- name: Run StatefulDataLoader tests with pytest - state_dict 3
9191
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
9292
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard3
93+
- name: Run StatefulDataLoader HuggingFace tests
94+
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
95+
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_hugging_face.py

test/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ portalocker >= 2.0.0
1111
# Protobuf 3.20.2 is also broken on MacOS Python 3.10
1212
# See: https://github.com/protocolbuffers/protobuf/issues/10571
1313
protobuf >= 3.9.2, < 3.20
14-
datasets
14+
datasets @ git+https://github.com/huggingface/datasets@main
1515
graphviz
1616
adlfs
1717
awscli>=1.27.66
test/stateful_dataloader/test_hugging_face.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import itertools
2+
3+
from datasets.info import DatasetInfo
4+
from datasets.iterable_dataset import ExamplesIterable, IterableDataset
5+
from torch.testing._internal.common_utils import IS_MACOS, TestCase
6+
from torchdata.stateful_dataloader import StatefulDataLoader
7+
8+
9+
# Number of examples generated per file path when the caller does not pass ``n``.
DEFAULT_N_EXAMPLES = 20
# Path label used in example keys when the caller does not pass ``filepaths``.
DEFAULT_FILEPATH = "file.txt"


def generate_examples_fn(**kwargs):
    """Yield ``(key, example)`` pairs for a synthetic dataset.

    Recognised keyword arguments:
        n: number of examples to emit per file path
            (defaults to ``DEFAULT_N_EXAMPLES``).
        filepaths: iterable of path labels; when missing or empty a single
            ``DEFAULT_FILEPATH`` label is used instead.

    Every other keyword argument is copied into each example dict.

    Yields:
        ``(f"{filepath}_{i}", {"id": i, **extras})`` for each path and each
        ``i`` in ``range(n)``. A ``"filepath"`` entry is added to the example
        whenever ``filepaths`` was passed as anything other than ``None``.
    """
    extra_fields = dict(kwargs)
    n_examples = extra_fields.pop("n", DEFAULT_N_EXAMPLES)
    filepaths = extra_fields.pop("filepaths", None)

    if filepaths:
        paths = filepaths
    else:
        paths = [DEFAULT_FILEPATH]

    # The check is on ``None`` (not truthiness): an explicitly passed empty
    # list still records the fallback path in each example.
    record_path = filepaths is not None

    for path in paths:
        if record_path:
            extra_fields["filepath"] = path
        for idx in range(n_examples):
            yield f"{path}_{idx}", {"id": idx, **extra_fields}
22+
23+
24+
def identity(x):
    """Return ``x`` unchanged.

    Passed as ``collate_fn`` to the data loaders in this module so that
    batches come back exactly as the dataset produced them.
    """
    return x
26+
27+
28+
class TestStatefulDataLoaderIterable_shard0(TestCase):
    """Checkpoint/resume tests for ``StatefulDataLoader`` over a HuggingFace ``IterableDataset``.

    Each test builds a loader, consumes ``interrupt`` batches, captures the
    loader's ``state_dict()``, drains the remaining batches, then loads that
    state into a freshly constructed loader and asserts that it yields the
    same remaining batches.
    """

    def _get_dataset(self):
        """Build an ``IterableDataset`` backed by ``generate_examples_fn`` with no extra kwargs."""
        ex_iterable = ExamplesIterable(generate_examples_fn, {})
        return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train")

    def _make_dataloader(self, dataset, num_workers, pw, every_n_steps):
        """Construct a ``StatefulDataLoader`` with the settings shared by both phases of a test."""
        return StatefulDataLoader(
            dataset=dataset,
            num_workers=num_workers,
            collate_fn=identity,
            snapshot_every_n_steps=every_n_steps,
            persistent_workers=pw,
            # "forkserver" is requested only on macOS when worker processes are used.
            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
        )

    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1):
        """Interrupt iteration, checkpoint, and verify a restored loader resumes identically.

        Args:
            num_workers: number of loader worker processes (0 = in-process).
            batch_size: accepted for the sub-test grid but NOT forwarded to the
                loader, same as before this change.
                NOTE(review): wiring it through looks intended, but with the
                default 20 generated examples ``batch_size=7`` would presumably
                produce fewer batches than the ``interrupt=10`` cases consume;
                revisit the parameter grid before forwarding it.
            pw: value passed as ``persistent_workers``.
            interrupt: number of batches to consume before taking the state dict.
            every_n_steps: value passed as ``snapshot_every_n_steps``.
        """
        dataset = self._get_dataset()
        dl = self._make_dataloader(dataset, num_workers, pw, every_n_steps)
        it = iter(dl)
        for _ in range(interrupt):
            next(it)

        state_dict = dl.state_dict()
        # Expected output: whatever the interrupted iterator still had left.
        exp = []
        for data in it:
            exp.append(data)

        # Restore new instance from state
        batches = []
        dl = self._make_dataloader(dataset, num_workers, pw, every_n_steps)
        dl.load_state_dict(state_dict)
        for batch in iter(dl):
            batches.append(batch)

        self.assertEqual(exp, batches)

    def test_no_mp(self):
        """Resume in the main process (no worker processes)."""
        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=0,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

    def test_mp_x(self):
        """Resume with 3 worker processes, non-persistent."""
        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=False,
                    interrupt=interrupt,
                )

    def test_mp_pw(self):
        """Resume with 3 persistent worker processes."""
        for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
            with self.subTest(batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=True,
                    interrupt=interrupt,
                )

    def test_mp_every_n_steps(self):
        """Resume with 3 persistent workers while varying the snapshot interval."""
        batch_size = 7
        for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
            with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
                self._run_and_checkpoint(
                    num_workers=3,
                    batch_size=batch_size,
                    pw=True,
                    interrupt=interrupt,
                    # Previously omitted, so every sub-test silently used the
                    # default interval of 1 and never exercised this setting.
                    every_n_steps=every_n_steps,
                )

0 commit comments

Comments
 (0)