[WIP] Apple Silicon (MPS/Metal) Support #3950

Draft
Wilbatronic wants to merge 845 commits into unslothai:main from Wilbatronic:apple-silicon-support

Conversation

@Wilbatronic

This PR introduces high-performance Apple Silicon support for Unsloth. The goal is to allow Mac users (M1/M2/M3/M4) to fine-tune and run inference on 7B+ models with performance parity to entry-level CUDA hardware, leveraging Apple's Unified Memory and Metal architecture.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Wilbatronic, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands Unsloth's hardware compatibility by introducing initial support for Apple Silicon (MPS/Metal) devices. The changes enable Mac users with M-series chips to leverage their hardware for fine-tuning and inference of large language models, aiming for performance comparable to entry-level CUDA GPUs. This is achieved through a combination of MPS-specific kernel implementations, intelligent device detection, and conditional module loading to adapt to the unique architecture of Apple Silicon, where libraries like Triton and bitsandbytes are not natively supported.

Highlights

  • Apple Silicon (MPS/Metal) Support: Introduced core support for Apple Silicon (MPS/Metal) devices, enabling Mac users to leverage their M-series chips for high-performance fine-tuning and inference of large language models.
  • PyTorch-Native Kernel Fallbacks: Implemented PyTorch-native fallback kernels for critical operations such as RMS LayerNorm, LayerNorm, RoPE embedding, Cross-Entropy Loss, SwiGLU, GEGLU, and LoRA operations. These replace Triton-based kernels on MPS, ensuring functionality and numerical parity.
  • Intelligent Device Detection and Capabilities: Added robust MPS device detection, including checks for bfloat16 support and unified memory information. The system now conditionally loads modules and functionalities, disabling unsupported libraries like Triton and bitsandbytes for MPS.
  • Quantization Handling and Warnings: Disabled bitsandbytes imports and functionalities for MPS devices, as they are not supported. Graceful fallbacks and user warnings are now in place for attempts to load quantized models on MPS, guiding users towards 16-bit models for optimal performance.
  • Comprehensive Testing Suite: Developed a new suite of unit and integration tests specifically for MPS. These tests verify device detection, numerical parity of MPS fallback kernels, and overall system integration, with conditional skipping on non-MPS hardware.
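The conditional device detection described in the highlights can be sketched roughly as follows. The function name, the `DEVICE_TYPE` constant, and the `USE_TRITON` flag are illustrative stand-ins, not the PR's actual identifiers:

```python
# Illustrative sketch (not the PR's actual code) of MPS-aware device detection.
def get_device_type():
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
    except ImportError:
        pass  # torch not installed: fall back to CPU
    return "cpu"

DEVICE_TYPE = get_device_type()
# Triton and bitsandbytes are not available on Apple Silicon, so modules
# depending on them would be conditionally disabled for MPS:
USE_TRITON = DEVICE_TYPE != "mps"
```

The PR then keys module loading off such a flag so that Triton-based kernels are never imported on MPS.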




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is an impressive and comprehensive pull request that adds Apple Silicon (MPS) support to Unsloth. The changes are well-structured, with clear separation of MPS-specific code, PyTorch-native fallbacks for Triton kernels, and a clean dispatching mechanism. The addition of extensive tests for numerical parity, integration, and sanity checks is commendable and crucial for ensuring correctness on the new backend.

I've identified a critical bug in the RoPE kernel that would lead to incorrect tensor shapes, a high-severity issue in the device stream handling logic, and some dead code that should be removed for better maintainability. After addressing these points, this PR will be a fantastic addition to the project, significantly expanding its user base to Mac users.

Comment on lines 291 to 293

if DEVICE_TYPE == "mps" and USE_MPS_FALLBACK:
from .mps.rope_embedding import mps_rope_embedding_qk


critical

There's a shape mismatch bug in the MPS fallback for RoPE. The mps_rope_embedding_qk function returns tensors with shape (batch, n_heads, seq_len, head_dim), but the caller of fast_rope_embedding expects the original shape (batch, seq_len, n_heads, head_dim).

The output tensors need to be transposed back to the expected shape before being returned, just like it's done for the Triton path. Without this, downstream operations will fail due to incorrect tensor shapes.

        q_out, k_out = mps_rope_embedding_qk(Q.transpose(1, 2).contiguous(), K.transpose(1, 2).contiguous(), cos, sin)
        return q_out.transpose(1, 2), k_out.transpose(1, 2)
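The shape contract behind this fix can be illustrated with toy NumPy arrays (the sizes are arbitrary; only the axis order matters):

```python
import numpy as np

# Callers of fast_rope_embedding pass and expect (batch, seq_len, n_heads, head_dim).
Q = np.zeros((2, 16, 8, 64))                        # (batch, seq_len, n_heads, head_dim)
Q_for_mps = np.transpose(Q, (0, 2, 1, 3))           # kernel layout: (batch, n_heads, seq_len, head_dim)
Q_returned = np.transpose(Q_for_mps, (0, 2, 1, 3))  # transpose back before returning
assert Q_for_mps.shape == (2, 8, 16, 64)
assert Q_returned.shape == (2, 16, 8, 64)
```

Without the final transpose, downstream code indexing `seq_len` on axis 1 would instead see `n_heads` there.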
Comment on lines +260 to 263
cgemm_4bit_inference_naive_fp16 = None
cgemm_4bit_inference_naive_bf16 = None
else:
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16


high

There's a bug in the torch_device_stream definition. torch.cuda.current_stream() is being called, which returns a stream object, whereas the other branches return a function that returns a stream object. This will cause an error on CUDA devices when torch_device_stream is used.

The expression is also quite complex and hard to read. I suggest refactoring it into a simple if/elif/else block for clarity and to fix the bug.

Suggested change
cgemm_4bit_inference_naive_fp16 = None
cgemm_4bit_inference_naive_bf16 = None
else:
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
if DEVICE_TYPE == "xpu":
torch_device_stream = torch.xpu.current_stream
elif DEVICE_TYPE == "mps":
torch_device_stream = lambda: None
else:
torch_device_stream = torch.cuda.current_stream
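A minimal illustration of the callable-vs-object distinction behind this bug, using stand-in classes rather than real torch streams:

```python
# Consumers invoke torch_device_stream(), so it must be a function
# that returns a stream, not an already-constructed stream object.
class FakeStream:
    pass

def current_stream():
    return FakeStream()

torch_device_stream = current_stream          # correct: still a callable
assert callable(torch_device_stream)
assert isinstance(torch_device_stream(), FakeStream)

broken = current_stream()                     # bug: a stream object, not a callable
assert not callable(broken)                   # calling broken() would raise TypeError
```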
Comment on lines 36 to 65

return out


class MPSLoRA_MLP(torch.autograd.Function):
@staticmethod
def forward(
ctx,
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
_forward_function,
):
# Forward pass using MPS-compatible operations
e = mps_matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
g = mps_matmul_lora(X, upW, upW_quant, upA, upB, upS)
h = _forward_function(e, g)


medium

The MPSLoRA_MLP class appears to be unused. The dispatch logic in unsloth/kernels/mps/dispatch.py for dispatch_lora_mlp_swiglu calls mps_apply_lora_mlp_swiglu, which uses a direct PyTorch-native implementation instead of this torch.autograd.Function.

Since the backward method is not implemented and raises a NotImplementedError, and the class itself is not being used, it would be best to remove it to avoid confusion and dead code in the repository.
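For reference, the PyTorch-native SwiGLU path the dispatcher does use computes roughly `silu(X @ gate) * (X @ up) @ down`. Here is a NumPy sketch of that computation; the weight names and sizes are made up for the example and do not reflect the repository's code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_mlp(X, W_gate, W_up, W_down):
    e = X @ W_gate                  # gate projection
    g = X @ W_up                    # up projection
    h = (e * sigmoid(e)) * g        # SwiGLU: silu(e) * g
    return h @ W_down               # down projection

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 32))
out = swiglu_mlp(
    X,
    rng.standard_normal((32, 64)),
    rng.standard_normal((32, 64)),
    rng.standard_normal((64, 32)),
)
assert out.shape == (4, 32)
```

Because plain tensor ops like these are differentiable by PyTorch's autograd automatically, a hand-written `torch.autograd.Function` with a custom backward is unnecessary on this path.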

@Wilbatronic force-pushed the apple-silicon-support branch 2 times, most recently from 4df41f8 to 17f4ddd on February 6, 2026 at 20:47
Replace use of mx.nn.silu with an explicit SiLU expression (gate * sigmoid(gate)) in benchmark_training to ensure the activation is computed as x * sigmoid(x).
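The identity relied on here is `silu(x) = x * sigmoid(x)`; a quick scalar check of the explicit form:

```python
import math

def silu(x):
    # explicit SiLU, matching the replacement described above
    return x * (1.0 / (1.0 + math.exp(-x)))

assert silu(0.0) == 0.0              # sigmoid(0) = 0.5, so 0 * 0.5 = 0
assert abs(silu(10.0) - 10.0) < 1e-3 # sigmoid saturates to 1 for large x
```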

Add a mock process_vision_info function to unsloth_zoo.vision_utils that returns (None, None) to provide MPS-compatible stubs and avoid attribute errors when the real vision utilities are unavailable.
Compute SiLU explicitly in benchmarks/mlx/benchmark_training.py by replacing mx.nn.sigmoid(gate) with 1/(1+exp(-gate)) for compatibility/consistency. In unsloth/patches.py add mocked unsloth_zoo.compiler functions: get_transformers_model_type (returns "llama") and unsloth_compile_transformers (no-op) to provide MPS-compatible stubs and avoid import/runtime errors.
Replace mx.nn.silu usage in benchmark_mlp_training with an explicit SiLU computation (h_pre = x @ w1 + b1; h = h_pre * sigmoid(h_pre) * (x @ w3)) to make the eager benchmark implementation explicit. Also add mock_temp.TEMPORARY_PATCHES = [] when mocking unsloth_zoo.temporary_patches and register it in sys.modules to avoid missing-attribute errors when that module is imported (improves compatibility for MPS/patch handling).
Add a new benchmark_finetune_training to benchmarks/mlx/benchmark_training.py that runs a full fine-tuning loop on a small Llama-style model (eager vs compiled) and registers it under the run_all_benchmarks dispatcher. Also patch unsloth/patches.py to provide two additional mocked helpers on unsloth_zoo.hf_utils (add_dtype_kwargs -> {} and fix_lora_auto_mapping -> None) for MPS compatibility.
Import mlx.optimizers as optim and update optimizer construction to use optim.SGD instead of mx.optimizers.SGD. Add a mock unsloth_zoo.temporary_patches.common module providing a torch_compile identity stub to avoid import/attribute errors when patching unsloth_zoo in environments without torch.compile (e.g., MPS/no-compile builds).
Remove dependency on mlx.optimizers in the training benchmark and replace per-parameter optimizer objects with a simple learning-rate-based gradient update (compute grad with mx.grad and apply p = p - lr * grad). This simplifies init/update logic for the finetune benchmark.

Also extend the unsloth_zoo loss mock to include no-op patch_loss_functions and post_patch_loss_function attributes to avoid attribute errors when patching for MPS compatibility.
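The simplified update rule described above (`p = p - lr * grad`) is plain gradient descent. A toy scalar example, minimizing `(p - 3)^2` with its analytic gradient `2 * (p - 3)`:

```python
# Toy stand-in for the mx.grad-based update in the benchmark.
def grad(p):
    return 2.0 * (p - 3.0)

p, lr = 0.0, 0.1
for _ in range(100):
    p = p - lr * grad(p)  # same shape of update as the commit describes

assert abs(p - 3.0) < 1e-6  # converges to the minimizer p = 3
```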
Integrate Unsloth MLX and custom Metal kernels into the training benchmark suite: attempt imports, set availability flags, and add benchmark cases (and compiled variants) that exercise mlx wrappers and Metal kernels alongside existing mx.fast ops. Also extend the compiled fused MLP in unsloth/kernels/mlx/fast_lora.py to accept optional gate_multiplier and down_multiplier parameters and apply them when provided, enabling runtime scaling of GELU gate and output (useful for LoRA-style multipliers).
Wrap inspect.getsource calls in try-except to handle MockModule
objects that don't have real source code (e.g., when modules are
mocked on MPS/MLX platforms).
- Improve MLX loader to use debug logging instead of warning for HuggingFace format weights
- Fix PyTorch memory tracking to use unified memory delta (CPU memory) on Apple Silicon
- Add get_memory_delta helper that reports CPU delta as GPU delta when GPU tracking returns 0
- This gives accurate memory usage since Apple Silicon uses unified memory
- Add process-specific memory tracking using resource.ru_maxrss and psutil
- Add materialize_model_on_gpu() to force weight loading into unified memory
- Update MemoryStats to track peak memory separately
- Fix get_memory_delta to properly handle unified memory semantics
- Add dummy forward pass after model loading to ensure accurate GPU memory measurement
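Process-level peak-RSS tracking of the kind these commits describe can be sketched with the standard-library `resource` module. The helper name is hypothetical; note that `ru_maxrss` is reported in kilobytes on Linux but bytes on macOS:

```python
import resource
import sys

def peak_rss_bytes():
    # Peak resident set size of this process, normalized to bytes.
    # On Apple Silicon this covers unified memory, which is why a
    # CPU-side measurement can stand in for "GPU" memory usage.
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return rss if sys.platform == "darwin" else rss * 1024

print(peak_rss_bytes())
```

A `get_memory_delta`-style helper would record this value before and after model loading and report the difference.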
…port

This commit adds:
-  kernel for MLX with mx.fast.cce_loss support.
- Autograd-aware  bridge for PyTorch on MPS.
- Native MLX Llama support for CCE.
- Updated  as a native MLX benchmark.
- Patcher support for CCE in  via  env var.
The error 'Attempting to eval an array without a primitive' occurred because:
1. mx.compile requires all inputs/outputs to be properly captured
2. model.parameters() and optimizer.state may contain arrays not in the computation graph

Fixed by removing @mx.compile decorator and evaluating loss, parameters,
and optimizer state separately after the gradient update.
- Print every step for first 10 steps, then every 10
- Show step time and rolling average
- Print model config at start
- Print total time summary at end
