Fixup mapper issues and resolve properly #4124
Conversation
For more information, see https://pre-commit.ci
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly refactors and enhances the model name resolution system, particularly addressing how FP8 models are handled by redirecting them to their BF16 counterparts, by centralizing this complex logic into a single, robust …

Highlights
Code Review
This pull request refactors the model name resolution logic, centralizing it in get_model_name and removing the now-redundant _redirect_fp8_to_bf16 function. It also introduces a comprehensive test suite for the new mapping logic. The changes are generally well-structured and improve maintainability. However, I've identified a critical issue in the refactored get_model_name function that breaks the offline FP8 quantization path. Additionally, the new tests could be expanded to cover load_in_fp8 scenarios to ensure full coverage of the updated functionality.
unsloth/models/loader_utils.py (265-268)
This refactoring seems to have introduced a bug. The logic in unsloth/models/loader.py (line 376) relies on get_model_name returning None to trigger offline FP8 quantization when load_in_fp8 is true and a pre-quantized model is not found.
The new implementation of get_model_name always returns a model name string, which will prevent the offline quantization path from being taken.
To fix this, you should restore the previous behavior for the load_in_fp8 case, which is to return new_model_name directly, even if it's None.
```python
if load_in_fp8 != False:
    # Handle on the fly TorchAO FP8 quantization
    return new_model_name
if new_model_name is None:
    new_model_name = model_name
return new_model_name
```

tests/test_get_model_name.py (12-18)
The tests for get_model_name don't seem to cover cases where load_in_fp8 is True. Since the signature of get_model_name has been updated to include load_in_fp8, it would be beneficial to expand the test matrix to include this parameter to ensure the new logic is fully tested.
You could update _assert_mapping to accept load_in_fp8 and then update the test cases in test_resolution_matrix accordingly. This would also require updating the loop in test_resolution_matrix to handle the new parameter.
```python
def _assert_mapping(self, model_name, load_in_4bit, load_in_fp8, expected, should_change):
    mapped = get_model_name(model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8)
    self.assertEqual(mapped.lower(), expected.lower())
    if should_change:
        self.assertNotEqual(mapped.lower(), model_name.lower())
    else:
        self.assertEqual(mapped.lower(), model_name.lower())
```
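As a sketch of how that expanded test matrix might look, here is a self-contained illustration with a stub resolver; the mapper entries and repo names below are hypothetical, not taken from unsloth's actual mapper.py:

```python
# Hypothetical FP8 -> BF16 mapping; real entries live in unsloth/models/mapper.py.
FP8_TO_BF16 = {"org/model-fp8": "org/model-bf16"}

def get_model_name(model_name, load_in_4bit = False, load_in_fp8 = False):
    # Stub resolver that preserves the None sentinel for FP8 requests
    # with no pre-quantized mapping, as the suggested fix does.
    if load_in_fp8 != False:
        return FP8_TO_BF16.get(model_name.lower())  # may be None
    return model_name

CASES = [
    # (model_name, load_in_4bit, load_in_fp8, expected)
    ("org/model-fp8",   False, True,  "org/model-bf16"),
    ("org/unknown-fp8", False, True,  None),  # sentinel -> offline quantization
    ("org/model-fp8",   False, False, "org/model-fp8"),
]

for name, in4bit, infp8, expected in CASES:
    assert get_model_name(name, load_in_4bit = in4bit, load_in_fp8 = infp8) == expected
```

Each row of `CASES` exercises one cell of the matrix, so adding `load_in_fp8` coverage is just a matter of adding rows rather than new test methods.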
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dc7ddd8831
```python
if new_model_name is None:
    new_model_name = model_name
```
There was a problem hiding this comment.
Preserve None sentinel for unresolved FP8 mapping
This fallback changes get_model_name(..., load_in_fp8=True) from returning None (the sentinel used by the loaders to trigger _offline_quantize_to_fp8) to always returning the original model string. In both loader flows, offline FP8 quantization is gated on new_model_name is None (see unsloth/models/loader.py around the if new_model_name is None and load_in_fp8 != False branches), so this makes that path unreachable for environments without vllm>=0.12 and can leave load_in_fp8 requests unfulfilled or failing later during load.
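The gating pattern described here can be sketched as follows; the function names and exact condition mirror this comment's description of the loader flow, not verified repository code:

```python
def resolve_for_load(model_name, load_in_fp8, get_model_name, offline_quantize):
    # Sketch of the loader-side flow: the None sentinel returned by
    # get_model_name is what routes a load_in_fp8 request into the
    # offline quantization path.
    new_model_name = get_model_name(model_name, load_in_fp8 = load_in_fp8)
    if new_model_name is None and load_in_fp8 != False:
        # No pre-quantized FP8 repo was found -> quantize locally.
        return offline_quantize(model_name)
    return new_model_name if new_model_name is not None else model_name

# If get_model_name always returns a string, the branch above becomes dead
# code and load_in_fp8 requests silently skip offline quantization.
```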
```python
# Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with
# FP8 weights. We just need to update it here for sanity.
auto_config.model_name = model_name
```
There was a problem hiding this comment.
Restore runtime FP8-to-BF16 redirect for unmapped models
This now assumes get_model_name already handled FP8-weight repos, but get_model_name only uses static mappers and does not inspect quantization_config or probe a -BF16 sibling. As a result, FP8 models that are not explicitly listed in mapper.py no longer get redirected before load; they continue with an FP8 config in non-FP8 modes, which can trigger unsupported-load failures that the previous _redirect_fp8_to_bf16 path avoided.
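A minimal sketch of such a runtime redirect is below; the quantization_config shape, the "-BF16" sibling naming convention, and the injected repo_exists probe are all assumptions, and the removed _redirect_fp8_to_bf16 may have worked differently:

```python
def redirect_fp8_to_bf16(model_name, quantization_config, repo_exists):
    # If the repo ships FP8 weights and we are not loading in FP8,
    # probe for a conventional "-BF16" sibling before giving up.
    quant = quantization_config or {}
    if str(quant.get("quant_method", "")).lower() not in ("fp8", "fbgemm_fp8"):
        return model_name  # not an FP8 repo, nothing to redirect
    for suffix in ("-FP8", "-fp8"):
        if model_name.endswith(suffix):
            sibling = model_name[: -len(suffix)] + "-BF16"
            break
    else:
        sibling = model_name + "-BF16"
    return sibling if repo_exists(sibling) else model_name
```

Unlike a static mapper lookup, this inspects the downloaded config itself, so FP8 repos missing from mapper.py would still be caught before a non-FP8 load fails.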
To run the test, from the root directory:

```
python -m unittest tests/test_get_model_name.py
```