feat: Apple Silicon (MPS) device support #4977

Closed
SAY-5 wants to merge 5 commits into unslothai:main from SAY-5:feat/apple-silicon-support

Conversation

@SAY-5

@SAY-5 SAY-5 commented Apr 12, 2026

Adds MPS (Metal Performance Shaders) support for Apple Silicon Macs.

device_type.py now detects MPS via torch.backends.mps. The rest of the changes make the import chain survive on a machine with no CUDA/Triton/bitsandbytes — Triton kernels get None stubs, bitsandbytes is skipped, and AMP decorators target the mps device type. Memory reporting uses recommended_max_working_set_size() and bfloat16 support is probed at runtime rather than hardcoded.

16-bit finetuning should work on MPS. 4-bit QLoRA still needs bitsandbytes to ship MPS support upstream.

Closes #4

- device_type.py: MPS detection via torch.backends.mps
- models/_utils.py: MPS branches for amp, bfloat16, memory
- kernels/utils.py: Guard Triton/bitsandbytes with MPS stubs
- __init__.py: MPS branch in device setup

All 4 files pass py_compile syntax check. Closes unslothai#4
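The "None stubs" approach described above can be sketched as a generic optional-import helper. This is a hedged illustration, not the actual diff; `optional_import` is a hypothetical name:

```python
import importlib
import importlib.util

def optional_import(name):
    # Return the module if it can be imported, else None so downstream
    # code can feature-gate on `module is not None` (the "None stub").
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)

triton = optional_import("triton")        # None on most Macs
bnb = optional_import("bitsandbytes")     # None when not installed
json_mod = optional_import("json")        # stdlib: always present
```

Callers then branch on `triton is None` instead of crashing at import time with `ModuleNotFoundError`.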
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for Apple Silicon (MPS) devices, adding device detection, conditional imports for Triton and bitsandbytes, and stubs for stream management. Review feedback highlights that the memory calculation for MPS should utilize torch.mps.recommended_max_working_set_size() instead of driver_allocated_memory(), and that AMP decorators should target the mps device type rather than cpu to ensure proper autocast functionality during 16-bit finetuning.

Comment thread unsloth/models/_utils.py Outdated
Comment on lines +1517 to +1521
elif DEVICE_TYPE == "mps":
    # MPS shares unified memory; report recommended allocator limit
    total_memory = torch.mps.driver_allocated_memory() or (
        int(os.popen("sysctl -n hw.memsize").read().strip())
    )
Contributor


high

The logic for calculating total_memory on MPS is flawed. torch.mps.driver_allocated_memory() returns the amount of memory currently in use by the MPS driver, not the total available capacity. Using it with an or fallback means that as soon as any memory is allocated (even a few MBs), the reported "total memory" will be that small allocated amount instead of the actual system limit.

On Apple Silicon, the appropriate API to determine the maximum memory the GPU can safely use is torch.mps.recommended_max_working_set_size(). This provides a much more accurate representation of the available "VRAM" in a unified memory architecture.

    elif DEVICE_TYPE == "mps":
        # MPS shares unified memory; report recommended allocator limit
        total_memory = torch.mps.recommended_max_working_set_size()
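The pitfall the reviewer describes comes from Python's `or` fallback: it only falls through when the left side is falsy (0 bytes allocated). A minimal illustration with made-up byte counts:

```python
def total_memory_buggy(driver_allocated: int, hw_memsize: int) -> int:
    # Mirrors the original pattern: the sysctl fallback is used only
    # when driver_allocated_memory() returns 0.
    return driver_allocated or hw_memsize

fresh = total_memory_buggy(0, 64 * 2**30)                # fallback: 64 GiB
after_alloc = total_memory_buggy(8 * 2**20, 64 * 2**30)  # reports 8 MiB as "total"
```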

Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.

Comment thread unsloth/kernels/utils.py Outdated
Comment on lines +54 to +57

if Version(torch.__version__) < Version("2.4.0"):
    if DEVICE_TYPE == "mps":
        # MPS does not support autocast custom_fwd/bwd; use CPU fallback
        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
Contributor


medium

Using device_type="cpu" for the AMP decorators on MPS will effectively disable autocast for any operations performed on MPS tensors within the decorated functions, as the CPU autocast context does not track MPS device operations. Since PyTorch (2.0+) supports MPS autocast, you should use device_type="mps" to ensure that 16-bit finetuning works correctly with these decorators.

Suggested change

-if Version(torch.__version__) < Version("2.4.0"):
-    if DEVICE_TYPE == "mps":
-        # MPS does not support autocast custom_fwd/bwd; use CPU fallback
-        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
+if DEVICE_TYPE == "mps":
+    # Use mps device type for autocast decorators
+    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "mps")
+    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "mps")
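The core of this comment is that an autocast context only affects ops on its own device type, so a `"cpu"` decorator leaves MPS tensors in full precision. A toy model of that rule, using a hypothetical helper rather than the real decorator machinery:

```python
def autocast_covers(decorator_device: str, tensor_device: str) -> bool:
    # torch.amp autocast contexts are per-device: ops on tensors of a
    # different device run outside the context, i.e. un-autocasted.
    return decorator_device == tensor_device

cpu_on_mps = autocast_covers("cpu", "mps")   # the CPU fallback: no coverage
mps_on_mps = autocast_covers("mps", "mps")   # the suggested fix
```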

Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.

Comment thread unsloth/models/_utils.py Outdated
Comment on lines +877 to +880
elif DEVICE_TYPE == "mps":
    # MPS does not support autocast custom_fwd/bwd; use CPU fallback for amp decorators
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cpu")
Contributor


medium

Using device_type="cpu" for the AMP decorators on MPS will effectively disable autocast for any operations performed on MPS tensors within the decorated functions. Since PyTorch supports MPS autocast, you should use device_type="mps" to ensure that 16-bit finetuning works correctly with these decorators.

Suggested change

-elif DEVICE_TYPE == "mps":
-    # MPS does not support autocast custom_fwd/bwd; use CPU fallback for amp decorators
-    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cpu")
-    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cpu")
+elif DEVICE_TYPE == "mps":
+    # Use mps device type for autocast decorators
+    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "mps")
+    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "mps")

Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d1832d3e27

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/device_type.py
Comment on lines +44 to +45
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"


P1: Guard MPS detection until optional deps are import-safe

Returning "mps" here lets initialization proceed into model imports, but those imports still include hard dependencies on CUDA-only libraries (for example unsloth/models/llama.py imports triton and unsloth/models/_utils.py imports bitsandbytes). On Apple Silicon setups where those packages are not installed, import unsloth now fails with ModuleNotFoundError, so the new MPS path does not actually initialize. Please either make those imports conditional first or gate MPS selection until dependency checks pass.


Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.

Comment thread unsloth/kernels/utils.py Outdated
MAX_FUSED_SIZE: int = 65536
import functools

from .fp8 import weight_dequant, fp8_linear


P1: Skip fp8 Triton import when initializing on MPS

This import is unconditional, but unsloth/kernels/fp8.py immediately imports triton; in the same change, the MPS branch explicitly treats Triton as unavailable. On Macs without Triton installed, importing unsloth.kernels.utils crashes before fallback logic can run, which breaks MPS startup even for non-FP8 use cases. Make FP8 loading conditional or lazy on non-MPS paths.
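One way to make that import conditional, per the review. This is a sketch: `DEVICE_TYPE` here stands in for the real detection result, the module path is hypothetical, and the stubbed names follow the PR's None-stub convention:

```python
DEVICE_TYPE = "mps"  # stand-in for the real device detection

if DEVICE_TYPE != "mps":
    # Triton-backed module; only safe to import where Triton exists.
    from fp8 import weight_dequant, fp8_linear  # hypothetical path
else:
    # None stubs keep `import unsloth.kernels.utils` alive on Macs;
    # callers must check for None before taking any FP8 path.
    weight_dequant = fp8_linear = None
```

Because the `import` statement sits behind the device check, merely importing the module no longer pulls in Triton on MPS machines.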


Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.

- Use torch.mps.recommended_max_working_set_size() for memory
  calculation instead of driver_allocated_memory() (which reports
  current allocation, not total available)
- Use device_type="mps" for AMP decorators instead of "cpu"
  (MPS supports autocast since PyTorch 2.3+)
- Guard fp8 Triton import with MPS check (fp8.py imports triton
  unconditionally)

All 4 files pass py_compile syntax check.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f8a41529d8


Comment thread unsloth/models/_utils.py Outdated
    SUPPORTS_BFLOAT16 = True
elif DEVICE_TYPE == "mps":
    # Apple Silicon (M1+) natively supports bfloat16
    SUPPORTS_BFLOAT16 = True


P1: Gate MPS bfloat16 on real runtime support

Setting SUPPORTS_BFLOAT16 = True unconditionally for MPS makes is_bfloat16_supported() report support even on Apple environments where PyTorch MPS bf16 is unavailable (for example older macOS/runtime combinations), so model loading can default to torch.bfloat16 and then fail at runtime. This should be derived from an actual MPS bf16 capability check instead of hardcoding True.


Author


Fixed in commit f8a4152 — updated to use recommended_max_working_set_size() for memory, device_type="mps" for AMP decorators, and guarded fp8 Triton import.

Author


Fixed in 5e1f2e7 — now probes MPS bfloat16 support at runtime by creating a test tensor, instead of hardcoding True. Falls back to False on older macOS/PyTorch combinations.

SAY-5 and others added 2 commits April 11, 2026 19:18
Instead of hardcoding SUPPORTS_BFLOAT16=True for MPS, probe at
runtime by creating a bfloat16 tensor on MPS. This handles older
macOS/PyTorch combinations where bf16 isn't available.

Addresses Codex review feedback.
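The probe in this commit follows a general pattern: attempt the operation once and treat any failure as "unsupported". A backend-agnostic sketch of that pattern (the real code would create a `torch.bfloat16` tensor on `mps` inside the probe; names here are illustrative):

```python
def probe_capability(op) -> bool:
    # Run the candidate operation once; any exception means the current
    # macOS/PyTorch combination lacks support, so fall back to False.
    try:
        op()
        return True
    except Exception:
        return False

def _unsupported():
    raise RuntimeError("bf16 tensor creation failed on this backend")

supported = probe_capability(lambda: 1.0 + 1.0)  # trivially succeeds
unsupported = probe_capability(_unsupported)
```

Probing once at import time keeps the cost negligible while avoiding a hardcoded `True` that would fail later during model loading.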
@rolandtannous rolandtannous marked this pull request as draft April 12, 2026 09:38
@rolandtannous
Collaborator

Since we're in the final phases of our own in-house MLX integration, I have to close this in favor of the work that's already been done and will be released soon. Thank you for the effort though.


Development

Successfully merging this pull request may close these issues.

Apple Silicon Support

2 participants