Saveable Generator: initial import #2937
Merged
Commits (5, all by flexthink):
- 321d32e Saveable Generator: initial import
- dc0eac5 Saveable Generator: fix a comment
- 0eeee77 Saveable Generator: Cosmetic changes
- ea44129 Saveable Generator: Address CUDA deprecation
- faeb2e3 Saveable Generator: Update comments, add a device-based test
New module (imported as `speechbrain.utils.repro`, 172 added lines):

````python
"""Reproducibility tools

Author:
 * Artem Ploujnikov 2025
"""

import re

import torch

import speechbrain as sb
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


@sb.utils.checkpoints.register_checkpoint_hooks
class SaveableGenerator:
    """A wrapper that can be used to store the state of
    the random number generator in a checkpoint. It helps
    with reproducibility in long-running experiments.

    Currently, this only supports CPU and CUDA devices
    natively. If you need training on other architectures,
    consider implementing a custom generator.

    Running it on an unsupported device that does not use the
    Torch generator interface will simply fail to restore the
    state but will not cause an error.

    Typical use in hparams:
    ```yaml
    generator: !new:model.custom_model.SaveableGenerator # <-- Include the wrapper

    checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
        checkpoints_dir: !ref <save_folder>
        recoverables:
            model: !ref <model>
            lr_scheduler: !ref <lr_annealing>
            counter: !ref <epoch_counter>
            generator: !ref <generator>
    ```

    Arguments
    ---------
    generators : Mapping[str, Generator], optional
        A dictionary of named generator objects. If not provided,
        the default generators for CPU and CUDA will be used.

    Examples
    --------
    >>> import torch
    >>> from speechbrain.utils.repro import SaveableGenerator
    >>> from speechbrain.utils.checkpoints import Checkpointer
    >>> gena, genb = [torch.Generator().manual_seed(x) for x in [42, 24]]
    >>> saveable_gen = SaveableGenerator(
    ...     generators={"a": gena, "b": genb}
    ... )
    >>> tempdir = getfixture('tmpdir')
    >>> checkpointer = Checkpointer(
    ...     tempdir,
    ...     recoverables={"generator": saveable_gen})
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    2
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    4
    >>> _ = checkpointer.save_checkpoint()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    >>> _ = checkpointer.recover_if_possible()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    """

    def __init__(self, generators=None):
        if generators is None:
            generators = {"default": torch.default_generator}
            if torch.cuda.is_available():
                for idx in range(torch.cuda.device_count()):
                    generators[f"cuda:{idx}"] = _CudaDefaultGeneratorWrapper(
                        idx
                    )

        self.generators = generators

    @sb.utils.checkpoints.mark_as_saver
    def save(self, path):
        """Save the generator state for later recovery

        Arguments
        ---------
        path : str, Path
            Where to save. Will overwrite.
        """
        save_dict = {
            key: generator.get_state()
            for key, generator in self.generators.items()
        }
        torch.save(save_dict, path)

    @sb.utils.checkpoints.mark_as_loader
    def load(self, path, end_of_epoch):
        """Loads the generator state if the corresponding devices are
        present

        Arguments
        ---------
        path : str, Path
            Where to load from.
        end_of_epoch : bool
            Whether the checkpoint was end-of-epoch or not.
        """
        del end_of_epoch
        save_dict = torch.load(path)
        for key, state in save_dict.items():
            if key == "default":
                torch.default_generator.set_state(state)
                continue
            match = re.match(r"cuda:(\d+)", key)
            if match:
                if not torch.cuda.is_available():
                    logger.warning(
                        "Unable to restore RNG for %s, CUDA unavailable", key
                    )
                    continue
                idx = int(match.group(1))
                if idx > torch.cuda.device_count() - 1:
                    logger.warning(
                        "Unable to restore RNG for %s, device not found", key
                    )
                    continue
            self.generators[key].set_state(state)


class _CudaDefaultGeneratorWrapper:
    """A generator wrapper for default generators - because torch no longer
    exposes default_generators

    This class should not be used outside of SaveableGenerator

    Arguments
    ---------
    device : int | str
        The device index or identifier
    """

    def __init__(self, device):
        self.device = device

    def get_state(self):
        """Returns the generator state

        Returns
        -------
        result : torch.Tensor
            The generator state
        """
        return torch.cuda.get_rng_state(self.device)

    def set_state(self, new_state):
        """Sets the generator state

        Arguments
        ---------
        new_state : torch.Tensor
            The new state
        """
        torch.cuda.set_rng_state(new_state, self.device)
````
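The save/restore round trip that `SaveableGenerator` implements can be illustrated without PyTorch, using the standard library's `random.Random` in place of `torch.Generator` (a minimal sketch; the class and names here are mine, not part of the PR — `getstate()`/`setstate()` play the role of the Torch generator state API):

```python
import os
import pickle
import random
import tempfile


class SaveableStdlibGenerator:
    """Minimal stand-in for SaveableGenerator built on stdlib RNGs.

    Any object exposing a way to snapshot and restore its state works
    with the same save/load pattern used in the PR.
    """

    def __init__(self, generators):
        self.generators = generators  # name -> random.Random

    def save(self, path):
        # Collect every generator's state and serialize the snapshot
        state = {key: gen.getstate() for key, gen in self.generators.items()}
        with open(path, "wb") as f:
            pickle.dump(state, f)

    def load(self, path):
        # Restore each named generator from the saved snapshot
        with open(path, "rb") as f:
            state = pickle.load(f)
        for key, gen_state in state.items():
            self.generators[key].setstate(gen_state)


gen_a = random.Random(42)
gen_b = random.Random(24)
saveable = SaveableStdlibGenerator({"a": gen_a, "b": gen_b})

path = os.path.join(tempfile.mkdtemp(), "rng.ckpt")
saveable.save(path)

# Draw after saving, then restore and draw again: the sequence repeats,
# which is exactly the reproducibility property the unit tests check.
before = [gen_a.randint(0, 9) for _ in range(5)]
saveable.load(path)
after = [gen_a.randint(0, 9) for _ in range(5)]
assert before == after
```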
New unit-test file (60 added lines):

```python
"""Unit tests for reproducibility utilities"""

import warnings

import torch


def test_repro(tmpdir):
    from speechbrain.utils.checkpoints import Checkpointer
    from speechbrain.utils.repro import SaveableGenerator

    gen1 = torch.Generator()
    gen2 = torch.Generator()
    gen = SaveableGenerator({"gen1": gen1, "gen2": gen2})
    checkpointer = Checkpointer(tmpdir)
    checkpointer.add_recoverable("gen", gen)
    # NOTE: Move the state a bit
    torch.randint(1, 10, (10,), generator=gen1)
    torch.randn((3, 3), generator=gen2)

    # NOTE: Save the checkpoint and get a reference
    checkpointer.save_checkpoint()
    x1_ref = torch.randint(1, 10, (10,), generator=gen1)
    x2_ref = torch.randn((3, 3), generator=gen2)
    # NOTE: Move the state even more, simulate usage
    for _ in range(5):
        torch.randint(1, 10, (10,), generator=gen1)
        torch.randn((3, 3), generator=gen2)

    # NOTE: Recover and compare
    checkpointer.recover_if_possible()
    x1 = torch.randint(1, 10, (10,), generator=gen1)
    x2 = torch.randn((3, 3), generator=gen2)
    assert (x1 == x1_ref).all()
    assert x2.allclose(x2_ref)


def test_repro_with_device(tmpdir, device):
    from speechbrain.utils.checkpoints import Checkpointer
    from speechbrain.utils.repro import SaveableGenerator

    if device == "cpu" or device.startswith("cuda"):
        gen = SaveableGenerator()
        checkpointer = Checkpointer(tmpdir, recoverables={"gen": gen})
        for _ in range(10):
            torch.randint(0, 10, (20, 20), device=device)
            torch.rand((10, 10))
        checkpointer.save_checkpoint()
        x = torch.randint(0, 10, (20, 20), device=device)
        y = torch.rand((10, 10))
        checkpointer.recover_if_possible()
        x_check = torch.randint(0, 10, (20, 20), device=device)
        y_check = torch.rand((10, 10))
        assert (x == x_check).all()
        assert y.allclose(y_check)

    else:
        warnings.warn(
            f"Device {device} is currently unsupported for saveable generators"
        )
```
Some of our unit tests take a device parameter for testing on e.g. cuda. Perhaps we can do something similar here to ensure it works at least locally (I guess the CI is running on cpu).
@pplantinga: Done in a separate test. However, device support is currently limited.
So for now, this feature will be only for the most common use cases. Other devices can be added later by writing wrappers similar to the one I had for Cuda - if they support the functionality at all.
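As the author notes, supporting another backend only requires an object with the same `get_state()`/`set_state()` interface as the CUDA wrapper. A hedged sketch of what such an adapter might look like (the class and the `backend_get`/`backend_set` parameter names are mine, not from the PR; on Apple silicon one might pass e.g. `torch.mps.get_rng_state`/`torch.mps.set_rng_state`, untested here — the demonstration below substitutes a stdlib RNG for the device backend):

```python
import random


class BackendGeneratorWrapper:
    """Adapts a pair of backend RNG-state functions to the
    get_state()/set_state() interface SaveableGenerator expects.

    backend_get : callable returning the backend's RNG state
    backend_set : callable accepting a previously returned state
    """

    def __init__(self, backend_get, backend_set):
        self._get = backend_get
        self._set = backend_set

    def get_state(self):
        return self._get()

    def set_state(self, new_state):
        self._set(new_state)


# Demonstration with a stdlib RNG standing in for a device backend
rng = random.Random(7)
wrapper = BackendGeneratorWrapper(rng.getstate, rng.setstate)

snapshot = wrapper.get_state()  # checkpoint the state
first = rng.random()            # advance the generator
wrapper.set_state(snapshot)     # roll it back
replay = rng.random()           # the draw repeats exactly
assert first == replay
```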