[train][checkpoint] Add checkpoint_upload_mode to ray.train.report (#55637)
Conversation
Signed-off-by: Timothy Seah <[email protected]>
Code Review
This pull request introduces asynchronous checkpointing, a valuable feature for improving performance. The implementation is mostly well-structured, but I've identified a critical issue in the asynchronous handling logic that could lead to deadlocks under failure conditions. I've provided a detailed comment and a suggested fix for this. Additionally, there's a minor type hint mismatch that should be corrected for code clarity and correctness. It would also be beneficial to add tests for failure scenarios in asynchronous checkpointing to ensure the system's robustness.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Timothy Seah <[email protected]>
Looks great! Some suggestions for follow-up:

Happy to work on a follow-up PR branch with these fixes (try/finally + retries/timeouts + config), or help you with any of these.
Thanks for the awesome feedback!
Good callout - I am currently working on #55756, after which I will be able to propagate the exception up to the main thread. Once that's merged I will use that and add the try/finally to this PR.
I think this could be a good followup PR - feel free to create a new issue and assign it to yourself. I would suggest subclassing CheckpointConfig for Ray Train v2 and adding the knob there.
I think this would also be a good followup PR - feel free to create a GitHub issue and assign it to yourself.
Reports should happen in order of submission per worker but each report also forms a barrier across all workers (see the warning in https://docs.ray.io/en/latest/train/api/doc/ray.train.report.html). What kind of documentation did you have in mind? I think https://github.com/ray-project/ray/blob/master/python/ray/train/v2/api/train_fn_utils.py#L36 mentions something to that effect but feel free to suggest modifications. Can you elaborate on the test? I think https://github.com/ray-project/ray/pull/55637/files#diff-7df817b7f2f904c441481b97cc76db79e527b265b8b7a8fadbc78e722f16f208R147 tries to do something similar but let me know if you have any suggestions.
I think this could also be worth filing a GitHub issue and assigning to yourself. My only concerns are to make sure it doesn't slow down training and that it isn't too noisy e.g. maybe we can toggle it with an environment variable.
After some more thought, I think a ThreadPool would indeed be better - let me change the PR accordingly.
The `ThreadRunner` is an abstraction used by Ray Train to capture errors raised by the training function so they can be polled by the Ray Train controller. This PR extends the `ThreadRunner` to also capture errors raised by threads created by the training function e.g. async checkpoint upload threads (#55637). --------- Signed-off-by: Timothy Seah <[email protected]>
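The error-capturing mechanism described in that commit can be sketched independently of Ray. The following is a hypothetical, simplified `ThreadRunner` (the real implementation differs); it uses Python's `threading.excepthook` to also capture uncaught exceptions from threads spawned by the training function, such as async checkpoint upload threads:

```python
import threading

class ThreadRunner:
    """Simplified sketch: collect errors from the training function
    and from any threads it spawns (names and structure hypothetical)."""

    def __init__(self):
        self._error_lock = threading.Lock()
        self.errors = []

    def run(self, target):
        # Capture uncaught exceptions from any thread started while
        # `target` runs, by temporarily replacing threading.excepthook.
        previous_hook = threading.excepthook

        def hook(args):
            with self._error_lock:
                self.errors.append(args.exc_value)

        threading.excepthook = hook
        try:
            try:
                target()
            except Exception as e:  # error raised by the training fn itself
                with self._error_lock:
                    self.errors.append(e)
        finally:
            threading.excepthook = previous_hook

def train_fn():
    # Simulates an async checkpoint upload thread that fails.
    def failing_upload():
        raise RuntimeError("upload failed")

    t = threading.Thread(target=failing_upload)
    t.start()
    t.join()

runner = ThreadRunner()
runner.run(train_fn)
print([type(e).__name__ for e in runner.errors])  # -> ['RuntimeError']
```

The captured errors can then be polled by a controller, which is the role the Ray Train controller plays in the real abstraction.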
#56208) After #55637, `ray.train.report` will allow users to upload checkpoints from disk to remote storage asynchronously. If they want to use framework-specific async checkpointing like `torch.async_save`, they can manage `torch.async_save` themselves and then call `ray.train.report(…checkpoint_upload_mode=CheckpointUploadMode.NO_UPLOAD)`. However, it would also be nice to allow `ray.train.report` to handle rate limiting and report ordering for framework-specific async checkpointing as well. This PR achieves this by exposing a `checkpoint_upload_function` argument that can replace the `persist_current_checkpoint` call. --------- Signed-off-by: Timothy Seah <[email protected]>
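The `NO_UPLOAD` pattern described above might look roughly like the following. This is a hedged sketch, not code from the PR: the `CheckpointUploadMode` import path, the commented `torch.distributed.checkpoint.async_save` usage, and the local directory handling are all assumptions for illustration.

```python
import ray.train
from ray.train import Checkpoint
from ray.train import CheckpointUploadMode  # import path assumed

def train_fn(config):
    # Framework-specific async checkpointing, managed by the user, e.g.:
    #   fut = torch.distributed.checkpoint.async_save(state_dict, checkpoint_id=local_dir)
    #   fut.result()  # wait for the framework-level save to land on disk
    local_dir = "/tmp/ckpt"  # hypothetical path written by the framework

    # Report the checkpoint without having Ray Train upload it:
    ray.train.report(
        {"loss": 0.1},
        checkpoint=Checkpoint.from_directory(local_dir),
        checkpoint_upload_mode=CheckpointUploadMode.NO_UPLOAD,
    )
```

With the `checkpoint_upload_function` argument added in this follow-up, `ray.train.report` can instead drive a custom upload itself, so rate limiting and report ordering are handled by Ray Train rather than by the user.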
Summary
Implement async checkpoint uploads in `ray.train.report(..., checkpoint_upload_mode)`, supporting `SYNC` (default), `ASYNC`, and `NO_UPLOAD`. Also add `delete_local_checkpoint_after_upload` to control temporary local directory cleanup.

Implementation Summary
This PR implements async checkpointing by:

- Adding `checkpoint_upload_mode` to `ray.train.report` with three options: `SYNC`, `ASYNC`, and `NO_UPLOAD`.
- Adding `num_reported_checkpoints` and `num_attempted_reported_checkpoints` counters on the `TrainContext`. Regardless of which `checkpoint_upload_mode` the different Ray Train workers are using, we want to upload checkpoints in the order they were `ray.train.report`ed. Therefore, each Ray Train worker waits for its turn (`num_reported_checkpoints == current_report_attempt_number - 1`) before adding its checkpoint to the result queue.
- Using a `ThreadPoolExecutor` to guard against spawning too many checkpoint upload threads.
- Changing `run_train_fn` to wrap the `train_fn` in `train_fn_that_waits_for_threads`, because otherwise we could end up in the following situation: 1) the train function exits with pending report threads and the worker status is "finished"; 2) the controller sees the finished status and shuts down the worker group; 3) `result.fit` does not return all the reported checkpoints/metrics.
- Propagating errors through the `ThreadRunner`, but implementing "wait for threads" as a wrapper function: in the former case, that is the cleanest way for a nested thread to cause the entire worker to exit early, while in this case the target function can wait for the threads it creates without complicating the `ThreadRunner` abstraction.

A few other notes:
- We add `Checkpoint`s (instead of `Checkpoint` `ObjectRef`s) to the result queue because, with the `ObjectRef` approach, the controller would have to create a Ray task that updates controller state. This "driver creates a task that updates the driver" pattern is unwieldy to implement.
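The per-worker ordering and thread-pool behavior described in the implementation summary can be sketched independently of Ray. In this hypothetical sketch (all names are invented, not the PR's), uploads run concurrently in a bounded pool, but each report waits for its turn before adding its result to the queue, so results always appear in submission order:

```python
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

class OrderedReporter:
    """Sketch of ordered async checkpoint reporting (hypothetical names)."""

    def __init__(self, max_upload_threads=2):
        # Bounded pool guards against spawning too many upload threads.
        self._pool = ThreadPoolExecutor(max_workers=max_upload_threads)
        self._cond = threading.Condition()
        self._num_reported = 0    # like num_reported_checkpoints
        self._num_attempted = 0   # like num_attempted_reported_checkpoints
        self.results = Queue()    # stand-in for the result queue

    def report(self, checkpoint, upload_fn):
        with self._cond:
            self._num_attempted += 1
            attempt_number = self._num_attempted
        self._pool.submit(
            self._upload_and_enqueue, checkpoint, upload_fn, attempt_number
        )

    def _upload_and_enqueue(self, checkpoint, upload_fn, attempt_number):
        uploaded = upload_fn(checkpoint)  # uploads may finish out of order
        with self._cond:
            # Wait for our turn: all earlier reports must be enqueued first.
            self._cond.wait_for(
                lambda: self._num_reported == attempt_number - 1
            )
            self.results.put(uploaded)
            self._num_reported += 1
            self._cond.notify_all()

    def shutdown(self):
        # Like the train_fn wrapper: wait for pending uploads before exiting.
        self._pool.shutdown(wait=True)

reporter = OrderedReporter()
for i in range(5):
    # Simulated upload with a random delay; returns the checkpoint.
    reporter.report(i, lambda c: (time.sleep(random.uniform(0, 0.02)), c)[1])
reporter.shutdown()
order = [reporter.results.get() for _ in range(5)]
print(order)  # results are enqueued in submission order
```

Because the pool executes submitted tasks in FIFO order, the earliest unfinished report is always running, so the waiters can never deadlock.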
API Changes

This PR's only API changes are adding the following two arguments to `ray.train.report`:

- `checkpoint_upload_mode`: How to upload the checkpoint (`SYNC`, `ASYNC`, or `NO_UPLOAD`).
- `delete_local_checkpoint_after_upload`: Whether to delete the local checkpoint after uploading it. Users generally won't need to set this, since each checkpoint upload mode has its own default for `tempfile`-backed checkpoint directories - see the previous section for an explanation.

Here's a simple example of this API in action:
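The example itself did not survive extraction, so the following is a minimal sketch of what usage might look like, assuming `CheckpointUploadMode` is importable from `ray.train` (the exact import path and trainer setup may differ):

```python
import tempfile

import ray.train
from ray.train import Checkpoint
from ray.train import CheckpointUploadMode  # import path assumed

def train_fn(config):
    for epoch in range(3):
        # ... train one epoch ...
        ckpt_dir = tempfile.mkdtemp()
        # ... write model/optimizer state into ckpt_dir ...
        ray.train.report(
            {"epoch": epoch},
            checkpoint=Checkpoint.from_directory(ckpt_dir),
            # Upload in the background instead of blocking the training
            # loop; the local temp directory is cleaned up after upload
            # by the mode's default delete behavior.
            checkpoint_upload_mode=CheckpointUploadMode.ASYNC,
        )
```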
Testing

Looks like async reporting is indeed faster with the same loss on the PyTorch Ray Train example: https://docs.ray.io/en/latest/train/getting-started-pytorch.html

- Sync mode: 3m3s
- Async mode: 2m57s, with only ~0.22s blocking time when waiting for the last checkpoint upload