[WIP] Apple Silicon (MPS/Metal) Support #3950
Wilbatronic wants to merge 845 commits into unslothai:main from
Conversation
Summary of Changes

Hello @Wilbatronic, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands Unsloth's hardware compatibility by introducing initial support for Apple Silicon (MPS/Metal) devices. The changes enable Mac users with M-series chips to leverage their hardware for fine-tuning and inference of large language models, aiming for performance comparable to entry-level CUDA GPUs. This is achieved through a combination of MPS-specific kernel implementations, intelligent device detection, and conditional module loading to adapt to the unique architecture of Apple Silicon, where libraries like Triton and bitsandbytes are not natively supported.
Code Review
This is an impressive and comprehensive pull request that adds Apple Silicon (MPS) support to Unsloth. The changes are well-structured, with clear separation of MPS-specific code, PyTorch-native fallbacks for Triton kernels, and a clean dispatching mechanism. The addition of extensive tests for numerical parity, integration, and sanity checks is commendable and crucial for ensuring correctness on the new backend.
I've identified a critical bug in the RoPE kernel that would lead to incorrect tensor shapes, a high-severity issue in the device stream handling logic, and some dead code that should be removed for better maintainability. After addressing these points, this PR will be a fantastic addition to the project, significantly expanding its user base to Mac users.
unsloth/kernels/rope_embedding.py (Outdated)

    if DEVICE_TYPE == "mps" and USE_MPS_FALLBACK:
        from .mps.rope_embedding import mps_rope_embedding_qk
There's a shape mismatch bug in the MPS fallback for RoPE. The mps_rope_embedding_qk function returns tensors with shape (batch, n_heads, seq_len, head_dim), but the caller of fast_rope_embedding expects the original shape (batch, seq_len, n_heads, head_dim).
The output tensors need to be transposed back to the expected shape before being returned, just like it's done for the Triton path. Without this, downstream operations will fail due to incorrect tensor shapes.
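As a shape-only illustration of the round-trip (a plain-Python sketch with hypothetical helper names; the real code operates on torch tensors), the caller's layout must be restored before returning:

```python
# Sketch: the MPS kernel works in (batch, n_heads, seq_len, head_dim) layout,
# so the caller transposes dims 1 and 2 going in and back coming out.
# All names below are illustrative stand-ins, not the actual Unsloth code.

def transpose_1_2(shape):
    """Swap dims 1 and 2 of a 4-D shape, mirroring tensor.transpose(1, 2)."""
    b, d1, d2, hd = shape
    return (b, d2, d1, hd)

def mps_rope_embedding_qk_shapes(q_shape, k_shape):
    """Stand-in for the kernel: it preserves the (B, H, S, D) layout it receives."""
    return q_shape, k_shape

def fast_rope_embedding_shapes(q_shape, k_shape):
    # Callers pass (batch, seq_len, n_heads, head_dim) ...
    q_out, k_out = mps_rope_embedding_qk_shapes(
        transpose_1_2(q_shape), transpose_1_2(k_shape)
    )
    # ... and expect the same layout back, hence the final transpose.
    return transpose_1_2(q_out), transpose_1_2(k_out)

q = (2, 128, 8, 64)  # (batch, seq_len, n_heads, head_dim)
print(fast_rope_embedding_shapes(q, q)[0])  # (2, 128, 8, 64)
```

Without the final transpose, downstream code would see (2, 8, 128, 64) and fail on shape checks.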
Suggested change:

    q_out, k_out = mps_rope_embedding_qk(Q.transpose(1, 2).contiguous(), K.transpose(1, 2).contiguous(), cos, sin)
    return q_out.transpose(1, 2), k_out.transpose(1, 2)

        cgemm_4bit_inference_naive_fp16 = None
        cgemm_4bit_inference_naive_bf16 = None
    else:
        cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
There's a bug in the torch_device_stream definition. torch.cuda.current_stream() is being called, which returns a stream object, whereas the other branches return a function that returns a stream object. This will cause an error on CUDA devices when torch_device_stream is used.
The expression is also quite complex and hard to read. I suggest refactoring it into a simple if/elif/else block for clarity and to fix the bug.
Suggested change:

        cgemm_4bit_inference_naive_fp16 = None
        cgemm_4bit_inference_naive_bf16 = None
    else:
        cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
    if DEVICE_TYPE == "xpu":
        torch_device_stream = torch.xpu.current_stream
    elif DEVICE_TYPE == "mps":
        torch_device_stream = lambda: None
    else:
        torch_device_stream = torch.cuda.current_stream
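The callable-vs-object mismatch can be demonstrated without torch (a sketch with illustrative names; FakeStream/current_stream stand in for torch's stream API):

```python
# Every branch must yield a *factory* (a callable returning a stream),
# not a stream object: callers invoke torch_device_stream() later.

class FakeStream:
    """Illustrative stand-in for a device stream object."""
    pass

def current_stream():
    return FakeStream()

def pick_stream_factory(device_type):
    if device_type == "xpu":
        return current_stream        # the function itself, not current_stream()
    elif device_type == "mps":
        return lambda: None          # MPS has no stream API; uniform stub
    else:
        return current_stream

for dev in ("xpu", "mps", "cuda"):
    factory = pick_stream_factory(dev)
    assert callable(factory)         # uniform interface across devices

print(pick_stream_factory("mps")())  # None
```

Storing `current_stream()` (the call result) in one branch while other branches store the function is exactly the bug flagged above: invoking the stored stream object later raises TypeError.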
    class MPSLoRA_MLP(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            X,
            gateW, gateW_quant, gateA, gateB, gateS,
            upW, upW_quant, upA, upB, upS,
            downW, downW_quant, downA, downB, downS,
            _forward_function,
        ):
            # Forward pass using MPS-compatible operations
            e = mps_matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
            g = mps_matmul_lora(X, upW, upW_quant, upA, upB, upS)
            h = _forward_function(e, g)
The MPSLoRA_MLP class appears to be unused. The dispatch logic in unsloth/kernels/mps/dispatch.py for dispatch_lora_mlp_swiglu calls mps_apply_lora_mlp_swiglu, which uses a direct PyTorch-native implementation instead of this torch.autograd.Function.
Since the backward method is not implemented and raises a NotImplementedError, and the class itself is not being used, it would be best to remove it to avoid confusion and dead code in the repository.
Force-pushed from 4df41f8 to 17f4ddd
Replace use of mx.nn.silu with an explicit SiLU expression (gate * sigmoid(gate)) in benchmark_training to ensure the activation is computed as x * sigmoid(x). Add a mock process_vision_info function to unsloth_zoo.vision_utils that returns (None, None) to provide MPS-compatible stubs and avoid attribute errors when the real vision utilities are unavailable.
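The explicit SiLU form this commit switches to is simple enough to verify in plain Python (a sketch, not the MLX benchmark code):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    # SiLU(x) = x * sigmoid(x), the explicit form used in place of mx.nn.silu
    return x * sigmoid(x)

print(silu(0.0))            # 0.0
print(round(silu(1.0), 4))  # 0.7311
```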
Compute SiLU explicitly in benchmarks/mlx/benchmark_training.py by replacing mx.nn.sigmoid(gate) with 1/(1+exp(-gate)) for compatibility/consistency. In unsloth/patches.py add mocked unsloth_zoo.compiler functions: get_transformers_model_type (returns "llama") and unsloth_compile_transformers (no-op) to provide MPS-compatible stubs and avoid import/runtime errors.
Replace mx.nn.silu usage in benchmark_mlp_training with an explicit SiLU computation (h_pre = x @ w1 + b1; h = h_pre * sigmoid(h_pre) * (x @ w3)) to make the eager benchmark implementation explicit. Also add mock_temp.TEMPORARY_PATCHES = [] when mocking unsloth_zoo.temporary_patches and register it in sys.modules to avoid missing-attribute errors when that module is imported (improves compatibility for MPS/patch handling).
Add a new benchmark_finetune_training to benchmarks/mlx/benchmark_training.py that runs a full fine-tuning loop on a small Llama-style model (eager vs compiled) and registers it under the run_all_benchmarks dispatcher. Also patch unsloth/patches.py to provide two additional mocked helpers on unsloth_zoo.hf_utils (add_dtype_kwargs -> {} and fix_lora_auto_mapping -> None) for MPS compatibility.
Import mlx.optimizers as optim and update optimizer construction to use optim.SGD instead of mx.optimizers.SGD. Add a mock unsloth_zoo.temporary_patches.common module providing a torch_compile identity stub to avoid import/attribute errors when patching unsloth_zoo in environments without torch.compile (e.g., MPS/no-compile builds).
Remove dependency on mlx.optimizers in the training benchmark and replace per-parameter optimizer objects with a simple learning-rate-based gradient update (compute grad with mx.grad and apply p = p - lr * grad). This simplifies init/update logic for the finetune benchmark. Also extend the unsloth_zoo loss mock to include no-op patch_loss_functions and post_patch_loss_function attributes to avoid attribute errors when patching for MPS compatibility.
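The optimizer-free update described here reduces to plain SGD per parameter; a minimal sketch (illustrative names, scalar parameters instead of mx arrays):

```python
# p <- p - lr * grad, applied independently to each parameter; this is
# what replaces the mlx.optimizers objects in the finetune benchmark.

def sgd_step(params, grads, lr=0.01):
    return [p - lr * g for p, g in zip(params, grads)]

params = [1.0, -2.0]
grads  = [0.5,  1.0]
print(sgd_step(params, grads, lr=0.1))  # [0.95, -2.1]
```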
… enabled float16 by default
Integrate Unsloth MLX and custom Metal kernels into the training benchmark suite: attempt imports, set availability flags, and add benchmark cases (and compiled variants) that exercise mlx wrappers and Metal kernels alongside existing mx.fast ops. Also extend the compiled fused MLP in unsloth/kernels/mlx/fast_lora.py to accept optional gate_multiplier and down_multiplier parameters and apply them when provided, enabling runtime scaling of GELU gate and output (useful for LoRA-style multipliers).
Wrap inspect.getsource calls in try-except to handle MockModule objects that don't have real source code (e.g., when modules are mocked on MPS/MLX platforms).
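The guard described in this commit can be sketched as follows (hypothetical wrapper name; `inspect.getsource` raises TypeError for objects with no source and OSError when the source file cannot be found):

```python
import inspect

class MockModule:
    """Stand-in for a mocked module/object with no real source code."""
    pass

def safe_getsource(obj):
    # Mocked objects make inspect.getsource raise TypeError or OSError;
    # return None instead of crashing the patcher.
    try:
        return inspect.getsource(obj)
    except (TypeError, OSError):
        return None

print(safe_getsource(MockModule()) is None)  # True
print(safe_getsource(inspect.getsource) is not None)  # True: real source exists
```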
- Improve MLX loader to use debug logging instead of warning for HuggingFace format weights
- Fix PyTorch memory tracking to use unified memory delta (CPU memory) on Apple Silicon
- Add get_memory_delta helper that reports CPU delta as GPU delta when GPU tracking returns 0
- This gives accurate memory usage since Apple Silicon uses unified memory
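A minimal sketch of the unified-memory fallback described here (`get_memory_delta` is the name from the commit message, but its signature is not shown in this PR, so the one below is an assumption):

```python
# On Apple Silicon, CPU and GPU share one memory pool, so when the GPU
# counter reports 0 the process (CPU) delta is the best available estimate.

def get_memory_delta(gpu_delta_mb, cpu_delta_mb):
    if gpu_delta_mb == 0:
        # GPU allocator not tracked: report the unified (CPU) delta instead
        return cpu_delta_mb
    return gpu_delta_mb

print(get_memory_delta(0, 512))     # 512  (falls back to CPU delta)
print(get_memory_delta(1024, 512))  # 1024 (GPU tracking worked)
```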
- Add process-specific memory tracking using resource.ru_maxrss and psutil
- Add materialize_model_on_gpu() to force weight loading into unified memory
- Update MemoryStats to track peak memory separately
- Fix get_memory_delta to properly handle unified memory semantics
- Add dummy forward pass after model loading to ensure accurate GPU memory measurement
…port

This commit adds:
- kernel for MLX with mx.fast.cce_loss support.
- Autograd-aware bridge for PyTorch on MPS.
- Native MLX Llama support for CCE.
- Updated as a native MLX benchmark.
- Patcher support for CCE via env var.
The error 'Attempting to eval an array without a primitive' occurred because:
1. mx.compile requires all inputs/outputs to be properly captured
2. model.parameters() and optimizer.state may contain arrays not in the computation graph
Fixed by removing the @mx.compile decorator and evaluating loss, parameters, and optimizer state separately after the gradient update.
- Print every step for first 10 steps, then every 10
- Show step time and rolling average
- Print model config at start
- Print total time summary at end
This PR introduces high-performance Apple Silicon support for Unsloth. The goal is to allow Mac users (M1/M2/M3/M4) to fine-tune and run inference on 7B+ models with performance parity to entry-level CUDA hardware, leveraging Apple's Unified Memory and Metal architecture.