
Speedrun Submission: MuonH Optimizer #498

Open
dangxingyu wants to merge 16 commits into karpathy:master from dangxingyu:muonh

Conversation

@dangxingyu

@dangxingyu dangxingyu commented Feb 4, 2026

This PR is a submission of the MuonH (Hyperball, https://tinyurl.com/muonh) optimizer. We ran d24 with 12 tokens per parameter on an 8xH100 cluster, building on commit 8309b83, which includes the FP8 implementation.

Step 16704 | CORE metric: 0.2645  
Total training time: 167.91m  
Minimum validation bpb: 0.747843  

Architecture Change

  1. Parameterized RMSNorm. Following our original blog post, we replace RMSNorm with parameterized RMSNorm in mlp_norm and attn_norm.
  2. Vector multiplier before residual. We apply a learnable vector multiplier to the projection outputs of Attention/MLP before the residual, mimicking the zero init of the projection matrices in both the MLP and Attention blocks, to keep the change minimal relative to the baseline.

Initialization

# learnable RMSNorm gains plus zero-initialized per-channel output multipliers
self.attn_norm = nn.RMSNorm(config.n_embd)
self.mlp_norm = nn.RMSNorm(config.n_embd)
self.attn_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))
self.mlp_proj_scalar = nn.Parameter(torch.zeros(config.n_embd))

Forward

x = x + self.attn(self.attn_norm(x), ve, cos_sin, window_size, kv_cache) * self.attn_proj_scalar
x = x + self.mlp(self.mlp_norm(x)) * self.mlp_proj_scalar

Remark: directly zero-initializing the gamma of the MLP RMSNorm leads to poor loss, because gamma passes through relu^2, which makes the output a multiple of gamma^2. Gamma therefore receives exactly zero gradient from the beginning, and the MLP never learns.
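
A minimal sketch of this failure mode (hypothetical shapes and weights, not the repo's actual module): with gamma zero-initialized, the relu^2 activation kills the gradient to gamma entirely.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
gamma = torch.zeros(8, requires_grad=True)   # zero-initialized RMSNorm gain
w_fc, w_proj = torch.randn(32, 8), torch.randn(8, 32)

x_hat = x / x.norm(dim=-1, keepdim=True) * (8 ** 0.5)  # RMS normalization
h = torch.relu((gamma * x_hat) @ w_fc.T) ** 2           # relu^2 MLP activation
loss = (h @ w_proj.T).square().mean()
loss.backward()
print(gamma.grad.abs().max())  # tensor(0.) -- zero gradient, gamma is stuck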

Optimizer Change:

  1. Fused with Normuon. In Normuon, the update is rescaled back to the norm of the Muon update; in Hyperball, we rescale the norm of the update back to the weight norm. We fuse the two rescalings together. We also compute and cache the weight norms at the first optimizer step, and we remove cautious weight decay. (A self-contained single-matrix sketch follows the excerpt below.)
# cache weight norm once
if "p_norm" not in state:
    state["p_norm"] = stacked_params.norm(dim=(-2,-1), keepdim=True).clone()
p_norm = state["p_norm"]
# fused hyperball update core
def hyperball_step_fused(..., p_norm, ...):
    ...
    v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
    p_norm = p_norm.to(v_norm_new.dtype)
    final_scale = step_size * p_norm / v_norm_new.clamp_min(1e-10)
    g = g * final_scale.to(g.dtype)
    u = g.to(stacked_params.dtype)

    stacked_params.sub_(lr * u)
    p_new_norm = stacked_params.norm(dim=(-2, -1), keepdim=True).clamp_min(1e-10)
    stacked_params.mul_(p_norm / p_new_norm)
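
For intuition, here is a minimal self-contained sketch of a hyperball-style step on a single matrix (illustrative only; hyperball_step and its arguments are our naming, not the PR's fused kernel): scale the update to the cached weight norm, take the step, then retract the weights back onto the sphere of their initial norm.

import torch

def hyperball_step(p, update, lr, p_norm0):
    # scale the (orthogonalized) update to the cached initial weight norm
    u = update * (p_norm0 / update.norm().clamp_min(1e-10))
    p.sub_(lr * u)
    # retract: rescale the weights back to their initial norm
    p.mul_(p_norm0 / p.norm().clamp_min(1e-10))

w = torch.randn(64, 64)
n0 = w.norm().clone()
hyperball_step(w, torch.randn(64, 64), lr=0.02, p_norm0=n0)
print(torch.isclose(w.norm(), n0))  # tensor(True): weight norm is preserved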
  2. Decoupled learning rate schedule. We decouple the LR scheduling of AdamW and MuonH: we use warmdown=1.0 for MuonH and warmdown=0.3 for AdamW. See MuonH Implementation #499.
  3. Depth/Chinchilla scaling. Following Andrej's observation that $wd \sim (1/\text{depth})^2$ for the scaling series, we expect the effective learning rate to scale as $lr_{\text{eff}} \sim \sqrt{lr \cdot wd} \sim 1/\text{depth}$; a short sketch of the resulting rule follows this list.
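
As a concrete sketch of that rule (base_depth=12 is our assumption for the reference depth; the formula itself matches the remark on the logged LR further down):

def muonh_matrix_lr(depth, base_lr=0.02, base_depth=12):
    # lr_eff ~ 1/depth, so scale the base matrix LR by base_depth / depth
    return base_lr * base_depth / depth

print(muonh_matrix_lr(24))  # 0.01, the value logged in the d24 wandb run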

Here’s our run command line:

torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
  --depth=24 \
  --target-param-data-ratio=12 \
  --device-batch-size=16 \
  --total-batch-size=524288 \
  --window-pattern=SSSL \
  --matrix-optimizer=hyperball \
  --matrix-lr=0.02 \
  --warmdown-ratio=0.3 \
  --matrix-warmdown-ratio=1.0 \
  --fp8 --fp8-recipe=tensorwise \
  --run=muonh_d24_ratio12 \
  --model-tag=d24_muonh_fp8

Or refer to the script runs/quickrun_muonh.sh.

Here is our public wandb run for this PR: wandb.ai/xingyu20/nanochat/runs/uocq5uxw. We also tried applying MuonH to the earlier bfloat16-precision commit at d24 with TPP 11 and observed a boost there as well (wandb.ai/xingyu20/nanochat/runs/5f40sch5):

Step 15312 | CORE metric: 0.2739
Total training time: 177.42m
Minimum validation bpb: 0.749824

Remark: in wandb.ai/xingyu20/nanochat/runs/uocq5uxw, the logged learning rate is 0.01. This is because we had not implemented the depth scaling in code at that point; 0.01 is the actual learning rate, i.e. 0.01 = 0.02 * base_depth / args.depth.

@karpathy
Owner

karpathy commented Feb 4, 2026

Nice, I'll take a look today! As you may have seen in dev/LOG.md I tried MuonH earlier because I liked the idea when I stumbled by it online, but wasn't able to get it to work out of the box with a quick attempt. There's a lot of details to get right though!

@karpathy
Owner

karpathy commented Feb 4, 2026

Are you able to get gains even for d12? It would be encouraging if we could beat baseline at e.g. d12, d16, and also at d26 (GPT-2). For example I tried like this:

torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
  --depth=12 \
  --matrix-optimizer=hyperball \
  --matrix-lr=0.02 \
  --warmdown-ratio=0.3 \
  --matrix-warmdown-ratio=1.0 \
  --fp8 --fp8-recipe=tensorwise \
  --run=muonh_d12 \
  --model-tag=d12_muonh_fp8

but the result is worse than baseline. possibly something needs more tuning. looking...

@WhenWen

WhenWen commented Feb 4, 2026

We actually observed no significant improvement in our earlier bfloat16 d12 experiments either. We ran two baseline runs at the time, where the only difference was that we applied norm logging to the second one. There is a 0.002 difference between these two runs, and all our final validation bpb values fall within this range.

We also observe that adding more parameters generally hurts the loss at d12. For example, the run with only the vector multiplier gets a lower loss than the one with all the parameters we added. We think this may be because training is too short at d12. Therefore, we decided to run the d24 experiment directly.

Here are our runs in the wandb report (it was run on B200s, so the MFU is not very informative):
https://api.wandb.ai/links/xingyu20/z99nmwb6

We are now trying to run the scaling on more model sizes!

@karpathy
Owner

karpathy commented Feb 4, 2026

That's interesting, I'm seeing about a 3e-3 difference on d12. One thing to be careful with: I think the scalar_lr is set too high for some params (e.g. the rmsnorm params). Something to experiment with and be careful with on your end too, possibly. I'm running the d26 repro now and I'll play with d12 as well to try to at least not see a regression.

@karpathy
Owner

karpathy commented Feb 5, 2026

Ok so I am not able to reproduce this on my end for some reason. I am on your branch and ran with:

torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
  --depth=26 \
  --target-param-data-ratio=8.5 \
  --device-batch-size=16 \
  --total-batch-size=524288 \
  --window-pattern=SSSL \
  --matrix-optimizer=hyperball \
  --matrix-lr=0.02 \
  --warmdown-ratio=0.3 \
  --matrix-warmdown-ratio=1.0 \
  --fp8 --fp8-recipe=tensorwise \
  --run=muonh_d26_ratio8.5 \
  --model-tag=d26_muonh_fp8

[image: val_bpb training curves]

Which is the fair comparison to the latest leaderboard entry of d26 with ratio 8.5. Not 100% sure what's off, looking...

@WhenWen

WhenWen commented Feb 5, 2026

We also ran one d26 8 TPP run! Here is the link: https://wandb.ai/xingyu20/nanochat/runs/rwlqdwrr.

Step 14014 | CORE metric: 0.2646
total training time: 173.29m
Minimum validation bpb: 0.747614

Note that this run is slower in wall time than our d24 11 TPP run.

We did a back-of-the-napkin calculation, and it seems the step-wise overhead is 10% for the d26 FP8 models, whereas for bfloat16 models we previously observed overheads of typically only 5%. The overhead likely comes from the bfloat16 RMSNorm parameters, as we previously benchmarked MuonH on the original architecture and it had almost the same MFU. We are now running a new version with FP8 RMSNorm parameters and hope this resolves the step-wise overhead.

Thank you Andrej for spending today trying MuonH! One quick question, what is the final validation bpb for the pink run in your plot?

@karpathy
Owner

karpathy commented Feb 5, 2026

The MuonH run (pink) came down to 0.74523 for me.

Sounds good. Usually my process is:

  1. Run the baseline on master, copy pasting the thing from leaderboard. Important to have as the baseline on your actual setup.
  2. Make surgical change, run & compare. And for comparisons I look at something that looks like this
[image: val_bpb comparison plots]

i.e. the val_bpb with x axes of steps, training_time, and flops. The CORE metric. And then VRAM, MFU, tok/s.

Kaiyue Wen and others added 8 commits February 12, 2026 16:15

…ng improvements

Major changes:
- Add custom FP8 training module (replaces torchao dependency)
- Implement auto-calculated optimal batch sizes (1M for d26)
- Add hyperball data scaling
- Restore and tune momentum schedule (settled on 0.95)
- Add matrix warmup ratio and norm_lr parameters
- Improve weight decay scaling (Tepoch-based theory)
- Update d26 configuration and scaling laws
- Clarify MFU labeling as bf16_mfu
- Update leaderboard and documentation

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

- Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao
- Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call
- Remove torchao dependency mention from base_train.py help text
- Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

The custom fp8 module had a performance issue in reparam_linear:
it was doing reshape→matmul→reshape on every linear layer, and
torch.compile couldn't fuse these operations because _Float8Matmul
was marked @allow_in_graph (opaque to the compiler).

torchao's matmul_with_hp_or_float8_args handles N-D tensors directly
without external reshaping, allowing better fusion opportunities and
higher MFU.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

…ayers

- reparam_linear: uses torchao for efficient N-D tensor handling without reshaping
- Float8Linear layers: use the custom fp8 module (simpler, same performance)
- This gives us the best of both: high MFU and minimal dependencies

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Added _Float8MatmulND to fp8.py:
- Handles N-D input tensors efficiently
- Does reshaping internally (opaque to torch.compile)
- Prevents external reshape overhead that was causing MFU regression
- ~75 lines of clean, documented code

Benefits:
- No torchao dependency (removed from pyproject.toml)
- Same performance as torchao for reparam_linear
- Consistent with fp8.py's minimal philosophy (~350 total lines)
- All FP8 logic in one self-contained module

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>

Resolved conflicts in scripts/base_train.py by keeping muonh-submit features
(hyperball optimizer support, norm_lr parameter, matrix warmup ratio) while
incorporating latest master improvements.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
@WhenWen

WhenWen commented Feb 13, 2026

Over the past week, we made the following updates.

  1. We implemented a simple fusion of the scalar and gamma into the weight matrices in the forward pass, which cuts the step-wise overhead to less than 1%. (A small equivalence check follows the snippet below.)
import torch
import torch.nn.functional as F
# Guarded import (inferred from the `Float8Linear is not None` check below)
try:
    from nanochat.fp8 import Float8Linear, _Float8MatmulND
except ImportError:
    Float8Linear = None

def reparam_linear(module, x, gamma=None, scalar=None):
    """Linear with gamma/scalar folded into weight. Works with both nn.Linear and Float8Linear.

    gamma: RMSNorm learnable weight, folded into input dim of W  (w = w * gamma[None, :])
    scalar: projection scalar, folded into output dim of W       (w = scalar[:, None] * w)

    For FP8, uses minimal custom _Float8MatmulND which handles N-D tensors internally.
    """
    w = module.weight
    if gamma is not None:
        w = w * gamma[None, :]
    if scalar is not None:
        w = scalar[:, None] * w
    # FP8 path: use custom _Float8MatmulND for efficient N-D tensor handling
    # (reshaping is done internally, so torch.compile sees it as one opaque operation)
    if Float8Linear is not None and isinstance(module, Float8Linear):
        # Handle autocast (Float8Linear expects this)
        if torch.is_autocast_enabled():
            x = x.to(torch.get_autocast_gpu_dtype())
        output = _Float8MatmulND.apply(x, w)
        if module.bias is not None:
            output = output + module.bias.to(output.dtype)
        return output
    # BF16 path
    return F.linear(x, w)
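
A quick equivalence check for the bf16 path (hypothetical small shapes; illustrative only): folding gamma into the input dim and scalar into the output dim of W matches applying them around a plain linear.

import torch
import torch.nn as nn
import torch.nn.functional as F

lin = nn.Linear(8, 16, bias=False)
x = torch.randn(2, 4, 8)
gamma, scalar = torch.randn(8), torch.randn(16)

folded = F.linear(x, scalar[:, None] * (lin.weight * gamma[None, :]))
unfolded = scalar * F.linear(x * gamma, lin.weight)
print(torch.allclose(folded, unfolded, atol=1e-5))  # True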
  2. We eventually removed the RMSNorm gamma, keeping only c_proj_scalar for each layer, and used another vector, v_proj_scalar, for the ve gate matrices.
  3. We used a muP-style initialization for the ve gate matrix.
  4. Following the practice in the current repo, we fitted a learning rate scaling law using d12 with a 0.5M batch size. We observe that for Hyperball the learning rate scales roughly as (data size)^{-0.35}, and we apply this scaling rule directly across depths (a hedged sketch follows the figure).
[image: learning rate scaling law fit]
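
A minimal sketch of that rule (hyperball_lr, base_lr, and base_tokens are illustrative placeholders, not the exact fitted constants):

def hyperball_lr(data_tokens, base_lr=0.02, base_tokens=2.0e9):
    # LR scales roughly as (data size)^-0.35 relative to the d12 fit
    return base_lr * (data_tokens / base_tokens) ** -0.35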
  5. With the current implementation, the most performant setting for MuonH is still d24 12x TPP 0.5M. We get the following result:
Step 16704 | CORE metric: 0.2603
Total training time: 159.83m
Minimum validation bpb: 0.748849

The run is here https://wandb.ai/xingyu20/nanochat/runs/o7kxn80u

  6. We ran ablations against the baseline (at commit 1ec0a34, before the torchao commit). Here is the head-to-head comparison across the following four settings. We did observe a small loss increase in the d26 8.25x TPP 1M and d12 10.5x TPP 1M settings (although the CORE metric is slightly higher). We suspect that MuonH mostly works at larger TPP.
[images: head-to-head comparison plots]

Raw runs are here:
d24 12xTPP 0.5M ours: https://wandb.ai/xingyu20/nanochat/runs/xdwvezqw
d24 12xTPP 0.5M baseline: https://wandb.ai/xingyu20/nanochat/runs/lbim5fyf
d24 12xTPP 1M ours: https://wandb.ai/xingyu20/nanochat/runs/64j8b6hd
d24 12xTPP 1M baseline: https://wandb.ai/xingyu20/nanochat/runs/ger9cc86
d26 8.25xTPP 1M ours: https://wandb.ai/xingyu20/nanochat/runs/buqjzfej
d26 8.25xTPP 1M baseline: https://wandb.ai/xingyu20/nanochat/runs/b13ryott
d12 10.5xTPP 1M ours: https://wandb.ai/xingyu20/nanochat/runs/6yspozmk
d12 10.5xTPP 1M baseline: https://wandb.ai/xingyu20/nanochat/runs/koy7j1kv

Resolved conflicts:
- nanochat/fp8.py: Kept _Float8MatmulND class from muonh
- scripts/base_train.py: Kept dual lrm logging from muonh
@kschwethelm

Hey! Very interesting work. Wanted to flag a recent paper that validates and extends these findings: arXiv:2603.28743 (Ren et al., Microsoft, "Rethinking Language Model Scaling under Transferable Hypersphere Optimization").

I would be interested to hear what your opinion is on their findings :)

Collaborator

@svlandeg svlandeg left a comment


Hi @dangxingyu, thanks for all the work on this!

This PR has now gotten a bit stale, and there have been new leaderboard entries with better val_bpb and time. Will you be updating this PR and benchmark it against the new results? If not, we'll probably go ahead and close this one.

@svlandeg added the waiting (Waiting for user feedback/action) and potential_improvement labels and removed the improvement label on Apr 15, 2026
