From 321d32e39c113e13ba84d8d5d2e4b1e93b58b1cc Mon Sep 17 00:00:00 2001 From: flexthink Date: Thu, 12 Jun 2025 10:04:56 -0400 Subject: [PATCH 1/5] Saveable Generator: initial import --- speechbrain/utils/repro.py | 81 +++++++++++++++++++++++++++++++++++ tests/unittests/test_repro.py | 32 ++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 speechbrain/utils/repro.py create mode 100644 tests/unittests/test_repro.py diff --git a/speechbrain/utils/repro.py b/speechbrain/utils/repro.py new file mode 100644 index 0000000000..e8818cb3b7 --- /dev/null +++ b/speechbrain/utils/repro.py @@ -0,0 +1,81 @@ +"""Reproducibility tools + +Author: + * Artem Ploujnikov 2025 +""" + +import re +import speechbrain as sb +import torch + +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +@sb.utils.checkpoints.register_checkpoint_hooks +class SaveableGenerator: + """A wrapper that can be used to store the state of + the random number generator in a checkpoint. It helps + with reproducibility in long-running experiments. + + Currently, this only supports CPU and Cuda devices + natively. If you need training on other architectures, + consider implementing a custom generator. + + Running it on an unsupported device not using the Torch + generator interface will simply fail to restore the + state but will not cause an error. + + Sample usage in hparams: + ```yaml + generator: !new:model.custom_model.SaveableGenerator + checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + lr_scheduler: !ref + counter: !ref + generator: !ref *+ + ``` + + Arguments + --------- + generators : list, optional + A list of generator objects. 
If not provided, all + default generators for CPU and Cuda will be used + """ + + def __init__(self, generators=None): + if generators is None: + generators = { + "default": torch.default_generator + } + if torch.cuda.is_available(): + for idx, generator in torch.cuda.default_generators: + generators[f"cuda:{idx}"] = generator + self.generators = generators + + @sb.utils.checkpoints.mark_as_saver + def _save(self, path): + save_dict = { + key: generator.get_state() + for key, generator in self.generators.items() + } + torch.save(save_dict, path) + + @sb.utils.checkpoints.mark_as_loader + def _recover(self, path, end_of_epoch): + del end_of_epoch + save_dict = torch.load(path) + for key, state in save_dict.items(): + match = re.match(r"cuda:(\d+)", key) + if match: + if not torch.cuda.is_available(): + logger.warning("Unable to restore RNG for %s, CUDA unavailable", key) + continue + idx = match.group(1) + if idx > torch.cuda.device_count() - 1: + logger.warning("Unable to restore RNG for %s, device not found", key) + continue + self.generators[key].set_state(state) diff --git a/tests/unittests/test_repro.py b/tests/unittests/test_repro.py new file mode 100644 index 0000000000..c4c0bc3596 --- /dev/null +++ b/tests/unittests/test_repro.py @@ -0,0 +1,32 @@ +"""Unit tests for reproducibility utilities""" + +import torch + + +def test_repro(tmpdir): + from speechbrain.utils.repro import SaveableGenerator + from speechbrain.utils.checkpoints import Checkpointer + gen1 = torch.Generator() + gen2 = torch.Generator() + gen = SaveableGenerator({"gen1": gen1, "gen2": gen2}) + checkpointer = Checkpointer(tmpdir) + checkpointer.add_recoverable("gen", gen) + # NOTE: Move the state a bit + torch.randint(1, 10, (10,), generator=gen1) + torch.randn((3, 3), generator=gen2) + + # NOTE: Save the checkpoint and get a reference + checkpointer.save_checkpoint() + x1_ref = torch.randint(1, 10, (10,), generator=gen1) + x2_ref = torch.randn((3, 3), generator=gen2) + # NOTE: Move the state 
even more, simulate usage + for _ in range(5): + torch.randint(1, 10, (10,), generator=gen1) + torch.randn((3, 3), generator=gen2) + + # NOTE: Recover and compare + checkpointer.recover_if_possible() + x1 = torch.randint(1, 10, (10,), generator=gen1) + x2 = torch.randn((3, 3), generator=gen2) + assert (x1 == x1_ref).all() + assert x2.allclose(x2_ref) From dc0eac5d1e49a308d7bd0280368f5035b3c32ddb Mon Sep 17 00:00:00 2001 From: flexthink Date: Thu, 12 Jun 2025 10:13:10 -0400 Subject: [PATCH 2/5] Saveable Generator: fix a comment --- speechbrain/utils/repro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speechbrain/utils/repro.py b/speechbrain/utils/repro.py index e8818cb3b7..10d9d490c1 100644 --- a/speechbrain/utils/repro.py +++ b/speechbrain/utils/repro.py @@ -36,7 +36,7 @@ class SaveableGenerator: model: !ref lr_scheduler: !ref counter: !ref - generator: !ref *+ + generator: !ref ``` Arguments From 0eeee77d6edd38abd3fbd141253cadaf3a2d5456 Mon Sep 17 00:00:00 2001 From: flexthink Date: Thu, 12 Jun 2025 10:15:42 -0400 Subject: [PATCH 3/5] Saveable Generator: Cosmetic changes --- speechbrain/utils/repro.py | 15 +++++++++------ tests/unittests/test_repro.py | 3 ++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/speechbrain/utils/repro.py b/speechbrain/utils/repro.py index 10d9d490c1..4f13432569 100644 --- a/speechbrain/utils/repro.py +++ b/speechbrain/utils/repro.py @@ -5,9 +5,10 @@ """ import re -import speechbrain as sb + import torch +import speechbrain as sb from speechbrain.utils.logger import get_logger logger = get_logger(__name__) @@ -48,9 +49,7 @@ class SaveableGenerator: def __init__(self, generators=None): if generators is None: - generators = { - "default": torch.default_generator - } + generators = {"default": torch.default_generator} if torch.cuda.is_available(): for idx, generator in torch.cuda.default_generators: generators[f"cuda:{idx}"] = generator @@ -72,10 +71,14 @@ def _recover(self, path, end_of_epoch): match 
= re.match(r"cuda:(\d+)", key) if match: if not torch.cuda.is_available(): - logger.warning("Unable to restore RNG for %s, CUDA unavailable", key) + logger.warning( + "Unable to restore RNG for %s, CUDA unavailable", key + ) continue idx = match.group(1) if idx > torch.cuda.device_count() - 1: - logger.warning("Unable to restore RNG for %s, device not found", key) + logger.warning( + "Unable to restore RNG for %s, device not found", key + ) continue self.generators[key].set_state(state) diff --git a/tests/unittests/test_repro.py b/tests/unittests/test_repro.py index c4c0bc3596..5f3eadd6ee 100644 --- a/tests/unittests/test_repro.py +++ b/tests/unittests/test_repro.py @@ -4,8 +4,9 @@ def test_repro(tmpdir): - from speechbrain.utils.repro import SaveableGenerator from speechbrain.utils.checkpoints import Checkpointer + from speechbrain.utils.repro import SaveableGenerator + gen1 = torch.Generator() gen2 = torch.Generator() gen = SaveableGenerator({"gen1": gen1, "gen2": gen2}) From ea44129f1a4527007d4a16822eb9f7b095a002a4 Mon Sep 17 00:00:00 2001 From: flexthink Date: Thu, 12 Jun 2025 16:37:43 -0400 Subject: [PATCH 4/5] Saveable Generator: Address CUDA deprecation --- speechbrain/utils/repro.py | 84 +++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 19 deletions(-) diff --git a/speechbrain/utils/repro.py b/speechbrain/utils/repro.py index 4f13432569..b196bdcb9a 100644 --- a/speechbrain/utils/repro.py +++ b/speechbrain/utils/repro.py @@ -28,35 +28,32 @@ class SaveableGenerator: generator interface will simply fail to restore the state but will not cause an error. - Sample usage in hparams: - ```yaml - generator: !new:model.custom_model.SaveableGenerator - checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - lr_scheduler: !ref - counter: !ref - generator: !ref - ``` - Arguments --------- generators : list, optional - A list of generator objects. 
If not provided, all - default generators for CPU and Cuda will be used + A list of generator objects. If not provided, """ def __init__(self, generators=None): if generators is None: generators = {"default": torch.default_generator} if torch.cuda.is_available(): - for idx, generator in torch.cuda.default_generators: - generators[f"cuda:{idx}"] = generator + for idx in range(torch.cuda.device_count()): + generators[f"cuda:{idx}"] = _CudaDefaultGeneratorWrapper( + idx + ) + self.generators = generators @sb.utils.checkpoints.mark_as_saver - def _save(self, path): + def save(self, path): + """Save the generator state for later recovery + + Arguments + --------- + path : str, Path + Where to save. Will overwrite. + """ save_dict = { key: generator.get_state() for key, generator in self.generators.items() @@ -64,10 +61,24 @@ def _save(self, path): torch.save(save_dict, path) @sb.utils.checkpoints.mark_as_loader - def _recover(self, path, end_of_epoch): + def load(self, path, end_of_epoch): + """ + Loads the generator state if the corresponding devices are + present + + Arguments + --------- + path : str, Path + Where to load from. + end_of_epoch : bool + Whether the checkpoint was end-of-epoch or not. 
+ """ del end_of_epoch save_dict = torch.load(path) for key, state in save_dict.items(): + if key == "default": + torch.default_generator.set_state(state) + continue match = re.match(r"cuda:(\d+)", key) if match: if not torch.cuda.is_available(): @@ -75,10 +86,45 @@ def _recover(self, path, end_of_epoch): "Unable to restore RNG for %s, CUDA unavailable", key ) continue - idx = match.group(1) + idx = int(match.group(1)) if idx > torch.cuda.device_count() - 1: logger.warning( "Unable to restore RNG for %s, device not found", key ) continue self.generators[key].set_state(state) + + +class _CudaDefaultGeneratorWrapper: + """A generator wrapper for default generators - because torch no longer + exposes default_generators + + This class should not be used outside of SaveableGenerator + + Arguments + --------- + device : int|str + The device index or identifier""" + + def __init__(self, device): + self.device = device + + def get_state(self): + """Returns the generator state + + Returns + ------- + result : torch.Tensor + The generator state + """ + return torch.cuda.get_rng_state(self.device) + + def set_state(self, new_state): + """ "Sets the generator state + + Arguments + --------- + new_state : dict + The new state + """ + torch.cuda.set_rng_state(new_state, self.device) From faeb2e3981e67aaa22aeabf4bc9e4c055364707a Mon Sep 17 00:00:00 2001 From: flexthink Date: Mon, 23 Jun 2025 11:41:25 -0400 Subject: [PATCH 5/5] Saveable Generator: Update comments, add a device-based test --- speechbrain/utils/repro.py | 46 +++++++++++++++++++++++++++++++++-- tests/unittests/test_repro.py | 27 ++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/speechbrain/utils/repro.py b/speechbrain/utils/repro.py index b196bdcb9a..d6d7b57820 100644 --- a/speechbrain/utils/repro.py +++ b/speechbrain/utils/repro.py @@ -28,10 +28,52 @@ class SaveableGenerator: generator interface will simply fail to restore the state but will not cause an error. 
Typical usage in hparams:
speechbrain.utils.checkpoints import Checkpointer + from speechbrain.utils.repro import SaveableGenerator + + if device == "cpu" or device.startswith("cuda"): + gen = SaveableGenerator() + checkpointer = Checkpointer(tmpdir, recoverables={"gen": gen}) + for _ in range(10): + torch.randint(0, 10, (20, 20), device=device) + torch.rand((10, 10)) + checkpointer.save_checkpoint() + x = torch.randint(0, 10, (20, 20), device=device) + y = torch.rand((10, 10)) + checkpointer.recover_if_possible() + x_check = torch.randint(0, 10, (20, 20), device=device) + y_check = torch.rand((10, 10)) + assert (x == x_check).all() + assert y.allclose(y_check) + + else: + warnings.warn( + f"Device {device} is currently unsupported for saveable generations" + )