Speedrun Submission: MuonH Optimizer #498
dangxingyu wants to merge 16 commits into karpathy:master
Conversation
Nice, I'll take a look today! As you may have seen in dev/LOG.md, I tried MuonH earlier because I liked the idea when I stumbled upon it online, but wasn't able to get it to work out of the box with a quick attempt. There are a lot of details to get right though!
Are you able to get gains even for d12? It would be encouraging if we could beat baseline at e.g. d12, d16, and also at d26 (GPT-2). For example, I tried like this, but the result is worse than baseline; possibly something needs more tuning. Looking...
We actually observed no significant improvement in our earlier bfloat16 d12 experiments either. We ran two baseline runs at the time, and the only difference was that we applied norm logging to the second one. There is a 0.002 difference between these two runs, and all of our final validation bpb values fall into this range. We also observed that adding more parameters in general hurts the loss on d12. For example, the run with only the vector multiplier gets a lower loss than the one with all the parameters we added. We think this may be because training is too short at d12. Therefore, we decided to run the d24 experiment directly. Here are our runs in the wandb report (it was run on B200, so the MFU is not very informative). We are now trying to run the scaling on more model sizes!
That's interesting, I'm seeing about a 3e-3 difference on d12. One thing to be careful with is that I think the scalar_lr is set too high for some params (e.g. the rmsnorm params). Something to experiment with and be careful with on your end too, possibly. I'm running the d26 repro now and I'll play with d12 as well to try to at least not see a regression.
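One hedged way to act on that (an illustrative PyTorch parameter-group sketch with a toy model and an assumed scalar_lr value, not the actual nanochat optimizer setup) is to give the 1-D norm/scalar parameters their own, lower learning rate:

```python
import torch
import torch.nn as nn

# Toy stand-in model; in practice this would be the transformer.
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
scalar_lr = 0.02  # assumed lr currently shared by all non-matrix params

norm_gains = [p for p in model.parameters() if p.ndim == 1]    # norm gains, biases, scalars
other_params = [p for p in model.parameters() if p.ndim != 1]

optimizer = torch.optim.AdamW([
    {"params": other_params, "lr": scalar_lr},
    {"params": norm_gains, "lr": scalar_lr * 0.1},  # e.g. 10x lower for the rmsnorm-style gains
])
```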
We also ran one d26 8 TPP run! Here is the link: https://wandb.ai/xingyu20/nanochat/runs/rwlqdwrr. Note that this run is slower in wall time than our d24 11 TPP run. We did a back-of-the-napkin calculation and it seems the step-wise overhead is 10% for the d26 FP8 model, whereas we previously observed that for bfloat16 models the overhead is typically only 5%. The overhead likely comes from the bfloat16 RMSNorm parameter, as we have benchmarked MuonH on the original architecture before and it has almost the same MFU. We are running a new version with an FP8 RMSNorm parameter now and hope that it resolves the step-wise overhead. Thank you Andrej for spending today trying MuonH! One quick question: what is the final validation bpb for the pink run in your plot?
…ng improvements

Major changes:
- Add custom FP8 training module (replaces torchao dependency)
- Implement auto-calculated optimal batch sizes (1M for d26)
- Add hyperball data scaling
- Restore and tune momentum schedule (settled on 0.95)
- Add matrix warmup ratio and norm_lr parameters
- Improve weight decay scaling (Tepoch-based theory)
- Update d26 configuration and scaling laws
- Clarify MFU labeling as bf16_mfu
- Update leaderboard and documentation

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
- Update reparam_linear to use nanochat.fp8.Float8Linear instead of torchao
- Replace matmul_with_hp_or_float8_args with direct _Float8Matmul.apply call
- Remove torchao dependency mention from base_train.py help text
- Functionally equivalent: both use torch._scaled_mm, custom version ~3% faster

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
The custom fp8 module had a performance issue in reparam_linear: it was doing reshape→matmul→reshape on every linear layer, and torch.compile couldn't fuse these operations because _Float8Matmul was marked @allow_in_graph (opaque to the compiler). torchao's matmul_with_hp_or_float8_args handles N-D tensors directly without external reshaping, allowing better fusion opportunities and higher MFU.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
…ayers

- reparam_linear: uses torchao for efficient N-D tensor handling without reshaping
- Float8Linear layers: uses custom fp8 module (simpler, same performance)
- This gives us the best of both: high MFU and minimal dependencies

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
Added _Float8MatmulND to fp8.py:
- Handles N-D input tensors efficiently
- Does reshaping internally (opaque to torch.compile)
- Prevents external reshape overhead that was causing MFU regression
- ~75 lines of clean, documented code

Benefits:
- No torchao dependency (removed from pyproject.toml)
- Same performance as torchao for reparam_linear
- Consistent with fp8.py's minimal philosophy (~350 total lines)
- All FP8 logic in one self-contained module

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
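For illustration, here is a hedged sketch of the "reshape internally" pattern these commits describe; the class and argument names are assumptions and the fp8 kernel is replaced by a plain matmul, so this is not the actual _Float8MatmulND code:

```python
import torch

class _MatmulND(torch.autograd.Function):
    """Sketch: keep the N-D -> 2-D reshape inside the autograd Function so callers
    never wrap the op in external reshapes (the pattern that hurt torch.compile
    fusion). A real fp8 version would quantize and call torch._scaled_mm where the
    plain matmuls are below."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        x2d = x.reshape(-1, x.shape[-1])             # flatten leading dims internally
        out2d = x2d @ weight.t()                     # fp8 matmul would go here
        return out2d.reshape(*x.shape[:-1], weight.shape[0])

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        g2d = grad_out.reshape(-1, grad_out.shape[-1])
        grad_x = (g2d @ weight).reshape(x.shape)
        grad_w = g2d.t() @ x.reshape(-1, x.shape[-1])
        return grad_x, grad_w

# Usage: y = _MatmulND.apply(x, w) with x of shape (B, T, in) and w of shape (out, in).
```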
Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
Resolved conflicts in scripts/base_train.py by keeping muonh-submit features (hyperball optimizer support, norm_lr parameter, matrix warmup ratio) while incorporating latest master improvements.

Co-Authored-By: Claude Sonnet 4.5 (1M context) <noreply@anthropic.com>
Over the past week, we made the following updates.
The run is here: https://wandb.ai/xingyu20/nanochat/runs/o7kxn80u
Raw runs are here.
Resolved conflicts:
- nanochat/fp8.py: Kept _Float8MatmulND class from muonh
- scripts/base_train.py: Kept dual lrm logging from muonh
Hey! Very interesting work. Wanted to flag a recent paper that validates and extends these findings: arXiv:2603.28743 (Ren et al., Microsoft, "Rethinking Language Model Scaling under Transferable Hypersphere Optimization"). I would be interested to hear your opinion on their findings :)
svlandeg left a comment
Hi @dangxingyu, thanks for all the work on this!
This PR has now gotten a bit stale, and there have been new leaderboard entries with better val_bpb and time. Will you be updating this PR and benchmarking it against the new results? If not, we'll probably go ahead and close this one.





This PR is a submission of the MuonH (Hyperball, https://tinyurl.com/muonh) optimizer. We ran d24 with 12 tokens per parameter on an 8xH100 cluster. We build upon commit 8309b83 with the FP8 implementation.

Architecture Change
We add learnable gamma parameters to the RMSNorms mlp_norm and attn_norm.
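As a hedged illustration of this kind of change (assumed class name and init, not the exact nanochat module):

```python
import torch
import torch.nn as nn

class RMSNormWithGain(nn.Module):
    """Sketch of an RMSNorm carrying a learnable per-channel gain (gamma)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # 1-init assumed here; see the Remark below on why 0-init is bad for the MLP norm.
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.gamma * (x * rms)
```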
Initialization

Forward
Remark: Note that directly zero-initializing the gamma in the MLP RMSNorm leads to poor loss: the gamma passes through the relu^2 activation, which makes the output a multiple of gamma^2, so gamma gets zero gradient from the very beginning and the MLP never learns.
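A minimal sketch of why this happens (toy tensors, not the actual MLP): with a relu^2 activation, a zero-initialized gamma receives exactly zero gradient, so it can never move.

```python
import torch

x = torch.randn(8, 16)

gamma = torch.zeros(16, requires_grad=True)   # 0-init, as warned against above
out = torch.relu(gamma * x) ** 2              # relu^2 activation: output scales as gamma^2
out.sum().backward()
print(gamma.grad.abs().max())                 # tensor(0.) -> gamma is stuck

gamma = torch.ones(16, requires_grad=True)    # 1-init instead
out = torch.relu(gamma * x) ** 2
out.sum().backward()
print(gamma.grad.abs().max())                 # nonzero -> gamma can learn
```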
Optimizer Change:
We use warmdown=1.0 for MuonH and warmdown=0.3 for AdamW. See MuonH Implementation #499.
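To illustrate what the warmdown fraction means, here is an assumed constant-then-linear-decay schedule (a sketch, not nanochat's exact schedule code):

```python
def lr_multiplier(step: int, total_steps: int, warmdown: float) -> float:
    """Constant LR, then linear decay over the last `warmdown` fraction of training."""
    warmdown_steps = max(1, int(warmdown * total_steps))
    decay_start = total_steps - warmdown_steps
    if step < decay_start:
        return 1.0
    return max(0.0, (total_steps - step) / warmdown_steps)

# warmdown=1.0 (MuonH): the LR decays linearly over the entire run.
# warmdown=0.3 (AdamW): the LR stays constant for 70% of steps, then decays to 0.
```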
Here's our run command line, or refer to the script runs/quickrun_muonh.sh.

Here is our public wandb run for this PR: wandb.ai/xingyu20/nanochat/runs/uocq5uxw. We also tried to apply MuonH on the earlier commit with bfloat16 precision at d24 TPP 11 and also observed some boost: wandb.ai/xingyu20/nanochat/runs/5f40sch5.

Remark: In wandb.ai/xingyu20/nanochat/runs/uocq5uxw, the logged learning rate is 0.01. This is because we had not yet implemented the depth scaling at that point, and 0.01 is the actual learning rate: 0.01 = 0.02 * base_depth / args.depth.
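A quick check of that arithmetic (assuming base_depth = 12 for the d24 run, which makes the numbers line up):

```python
base_lr, base_depth, depth = 0.02, 12, 24   # base_depth = 12 is an assumption
lr = base_lr * base_depth / depth
print(lr)  # 0.01
```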