feat: Apple Silicon (MPS) device support #4977
Conversation
- device_type.py: MPS detection via torch.backends.mps
- models/_utils.py: MPS branches for amp, bfloat16, memory
- kernels/utils.py: Guard Triton/bitsandbytes with MPS stubs
- __init__.py: MPS branch in device setup

All 4 files pass py_compile syntax check. Closes unslothai#4
Code Review
This pull request introduces support for Apple Silicon (MPS) devices, adding device detection, conditional imports for Triton and bitsandbytes, and stubs for stream management. Review feedback highlights that the memory calculation for MPS should utilize torch.mps.recommended_max_working_set_size() instead of driver_allocated_memory(), and that AMP decorators should target the mps device type rather than cpu to ensure proper autocast functionality during 16-bit finetuning.
```python
elif DEVICE_TYPE == "mps":
    # MPS shares unified memory; report recommended allocator limit
    total_memory = torch.mps.driver_allocated_memory() or (
        int(os.popen("sysctl -n hw.memsize").read().strip())
    )
```
The logic for calculating total_memory on MPS is flawed. torch.mps.driver_allocated_memory() returns the amount of memory currently in use by the MPS driver, not the total available capacity. Using it with an or fallback means that as soon as any memory is allocated (even a few MBs), the reported "total memory" will be that small allocated amount instead of the actual system limit.
On Apple Silicon, the appropriate API to determine the maximum memory the GPU can safely use is torch.mps.recommended_max_working_set_size(). This provides a much more accurate representation of the available "VRAM" in a unified memory architecture.
```python
elif DEVICE_TYPE == "mps":
    # MPS shares unified memory; report recommended allocator limit
    total_memory = torch.mps.recommended_max_working_set_size()
```
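The failure mode the reviewer describes can be shown without torch at all. A minimal torch-free sketch of the `or` fallback, with made-up byte counts:

```python
# Torch-free sketch of the reviewer's point: with an `or` fallback,
# the sysctl path only fires while the driver has allocated nothing.
def total_memory(driver_allocated: int, system_memsize: int) -> int:
    # Mirrors: torch.mps.driver_allocated_memory() or sysctl hw.memsize
    return driver_allocated or system_memsize

SIXTEEN_GB = 16 * 1024**3

# Before any allocation the fallback reports the real capacity...
assert total_memory(0, SIXTEEN_GB) == SIXTEEN_GB
# ...but once even 1 MiB is allocated, "total memory" collapses to 1 MiB.
assert total_memory(1024**2, SIXTEEN_GB) == 1024**2
```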
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
```python
if Version(torch.__version__) < Version("2.4.0"):
    if DEVICE_TYPE == "mps":
        # MPS does not support autocast custom_fwd/bwd; use CPU fallback
        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
```
Using device_type="cpu" for the AMP decorators on MPS will effectively disable autocast for any operations performed on MPS tensors within the decorated functions, as the CPU autocast context does not track MPS device operations. Since PyTorch (2.0+) supports MPS autocast, you should use device_type="mps" to ensure that 16-bit finetuning works correctly with these decorators.
Suggested change:

```diff
 if Version(torch.__version__) < Version("2.4.0"):
     if DEVICE_TYPE == "mps":
-        # MPS does not support autocast custom_fwd/bwd; use CPU fallback
-        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
+        # Use mps device type for autocast decorators
+        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "mps")
+        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "mps")
```
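Why a cpu-targeted decorator is a no-op for MPS tensors can be sketched without torch. The `Autocast` class below is a toy stand-in for the real context manager, not the PyTorch API:

```python
# Toy model of autocast: a context only downcasts ops whose tensors
# live on ITS device type. Illustrative only; not the torch API.
class Autocast:
    active = set()
    def __init__(self, device_type):
        self.device_type = device_type
    def __enter__(self):
        Autocast.active.add(self.device_type)
    def __exit__(self, *exc):
        Autocast.active.discard(self.device_type)

def op_dtype(tensor_device: str) -> str:
    # An op is downcast to float16 only if autocast covers its device.
    return "float16" if tensor_device in Autocast.active else "float32"

with Autocast("cpu"):          # decorator targeting "cpu"...
    assert op_dtype("mps") == "float32"   # ...misses every mps op
with Autocast("mps"):          # targeting "mps"
    assert op_dtype("mps") == "float16"   # downcasts as intended
```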
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
```python
elif DEVICE_TYPE == "mps":
    # MPS does not support autocast custom_fwd/bwd; use CPU fallback for amp decorators
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cpu")
```
Using device_type="cpu" for the AMP decorators on MPS will effectively disable autocast for any operations performed on MPS tensors within the decorated functions. Since PyTorch supports MPS autocast, you should use device_type="mps" to ensure that 16-bit finetuning works correctly with these decorators.
Suggested change:

```diff
 elif DEVICE_TYPE == "mps":
-    # MPS does not support autocast custom_fwd/bwd; use CPU fallback for amp decorators
-    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
-    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cpu")
+    # Use mps device type for autocast decorators
+    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "mps")
+    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "mps")
```
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d1832d3e27
```python
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    return "mps"
```
Guard MPS detection until optional deps are import-safe
Returning "mps" here lets initialization proceed into model imports, but those imports still include hard dependencies on CUDA-only libraries (for example unsloth/models/llama.py imports triton and unsloth/models/_utils.py imports bitsandbytes). On Apple Silicon setups where those packages are not installed, import unsloth now fails with ModuleNotFoundError, so the new MPS path does not actually initialize. Please either make those imports conditional first or gate MPS selection until dependency checks pass.
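A guard along the lines this comment asks for could probe the optional packages before committing to the mps path. The `_has` helper and flag names below are hypothetical, not the PR's actual identifiers:

```python
import importlib.util

def _has(mod: str) -> bool:
    # True iff the package is importable, without importing it yet.
    return importlib.util.find_spec(mod) is not None

HAS_TRITON = _has("triton")          # CUDA-only; absent on most Macs
HAS_BNB = _has("bitsandbytes")

if HAS_TRITON:
    import triton
else:
    triton = None  # stub so downstream None-checks can no-op gracefully
```

Device selection could then require either CUDA or a fully import-safe MPS path before returning "mps".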
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
```python
MAX_FUSED_SIZE: int = 65536
import functools

from .fp8 import weight_dequant, fp8_linear
```
Skip fp8 Triton import when initializing on MPS
This import is unconditional, but unsloth/kernels/fp8.py immediately imports triton; in the same change, the MPS branch explicitly treats Triton as unavailable. On Macs without Triton installed, importing unsloth.kernels.utils crashes before fallback logic can run, which breaks MPS startup even for non-FP8 use cases. Make FP8 loading conditional or lazy on non-MPS paths.
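One way to make the import conditional, sketched with a stub fallback. The hardcoded `DEVICE_TYPE` value and the stub body are illustrative, not the PR's code:

```python
import importlib.util

DEVICE_TYPE = "mps"  # assume device detection already ran

if DEVICE_TYPE != "mps" and importlib.util.find_spec("triton") is not None:
    # Real path: fp8.py imports triton at module import time.
    from unsloth.kernels.fp8 import weight_dequant, fp8_linear
else:
    # Stub path: fail loudly only when fp8 is actually used,
    # so plain (non-FP8) MPS startup is unaffected.
    def weight_dequant(*args, **kwargs):
        raise NotImplementedError("fp8 kernels require Triton (CUDA)")
    fp8_linear = weight_dequant
```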
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
- Use torch.mps.recommended_max_working_set_size() for memory calculation instead of driver_allocated_memory() (which reports current allocation, not total available)
- Use device_type="mps" for AMP decorators instead of "cpu" (MPS supports autocast since PyTorch 2.3+)
- Guard fp8 Triton import with MPS check (fp8.py imports triton unconditionally)

All 4 files pass py_compile syntax check.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f8a41529d8
```python
    SUPPORTS_BFLOAT16 = True
elif DEVICE_TYPE == "mps":
    # Apple Silicon (M1+) natively supports bfloat16
    SUPPORTS_BFLOAT16 = True
```
Gate MPS bfloat16 on real runtime support
Setting SUPPORTS_BFLOAT16 = True unconditionally for MPS makes is_bfloat16_supported() report support even on Apple environments where PyTorch MPS bf16 is unavailable (for example older macOS/runtime combinations), so model loading can default to torch.bfloat16 and then fail at runtime. This should be derived from an actual MPS bf16 capability check instead of hardcoding True.
Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.
Fixed in 5e1f2e7 — now probes MPS bfloat16 support at runtime by creating a test tensor, instead of hardcoding True. Falls back to False on older macOS/PyTorch combinations.
Instead of hardcoding SUPPORTS_BFLOAT16=True for MPS, probe at runtime by creating a bfloat16 tensor on MPS. This handles older macOS/PyTorch combinations where bf16 isn't available. Addresses Codex review feedback.
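A probe in that spirit, assuming torch is importable; the function name is illustrative, not necessarily what the commit uses:

```python
import torch

def mps_supports_bfloat16() -> bool:
    # Probe instead of hardcoding True: older macOS/PyTorch builds
    # can expose MPS yet fail to materialize bfloat16 tensors on it.
    if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
        return False
    try:
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        return True
    except Exception:
        return False

SUPPORTS_BFLOAT16 = mps_supports_bfloat16()
```

On non-Mac machines (or builds without MPS) this simply returns False, so the flag degrades safely instead of crashing at model-load time.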
Since we're in the final phases of our own in-house MLX integration, I have to close this for the benefit of the work that's already been done and will be released soon. Thank you for the effort though.
Adds MPS (Metal Performance Shaders) support for Apple Silicon Macs.

`device_type.py` now detects MPS via `torch.backends.mps`. The rest of the changes make the import chain survive on a machine with no CUDA/Triton/bitsandbytes — Triton kernels get `None` stubs, bitsandbytes is skipped, and AMP decorators target the `mps` device type. Memory reporting uses `recommended_max_working_set_size()` and bfloat16 support is probed at runtime rather than hardcoded.

16-bit finetuning should work on MPS. 4-bit QLoRA still needs bitsandbytes to ship MPS support upstream.

Closes #4