get_param_name
SunMarc committed Oct 9, 2025
commit d831ab2e544f7748e64b39e2686b535dbd77acad
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
@@ -763,7 +763,7 @@ def _load_state_dict_into_meta_model(
# and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
- param_name = hf_quantizer.update_param_name(param_name)
+ param_name = hf_quantizer.get_param_name(param_name)
module, param_type = get_module_from_name(model, param_name)
value = getattr(module, param_type)
# We need to wait until the quantized value is created
@@ -5818,7 +5818,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
# For example in the case of MXFP4 quantization, we need to update the param name to the original param name
# because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
if hf_quantizer is not None:
- param_name = hf_quantizer.update_param_name(param_name)
+ param_name = hf_quantizer.get_param_name(param_name)

try:
param = model.get_parameter_or_buffer(param_name)
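Both call sites above follow the same pattern: translate the checkpoint key into the name the model actually exposes before looking the parameter up. A minimal, self-contained sketch of that pattern (ToyQuantizer/ToyModel are hypothetical stand-ins, not transformers code):

from torch import nn

class ToyQuantizer:
    # Stand-in for hf_quantizer: checkpoints may store packed tensors under a
    # suffixed key that does not exist as an attribute on the model.
    def get_param_name(self, param_name: str) -> str:
        for suffix in ("_blocks", "_scales"):
            if suffix in param_name:
                return param_name.replace(suffix, "")
        return param_name

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

model = ToyModel()
quantizer = ToyQuantizer()

checkpoint_key = "proj.weight_blocks"                  # name used in the serialized file
param_name = quantizer.get_param_name(checkpoint_key)  # -> "proj.weight"
param = model.get_parameter(param_name)                # lookup now succeeds
print(param_name, tuple(param.shape))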
2 changes: 1 addition & 1 deletion src/transformers/quantizers/base.py
@@ -283,7 +283,7 @@ def _dequantize(self, model):
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)

- def update_param_name(self, param_name: str) -> str:
+ def get_param_name(self, param_name: str) -> str:
"""
Override this method if you want to adjust the `param_name`.
"""
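The body of the base method is collapsed in this diff; a reasonable assumption, consistent with the docstring, is that the default is an identity mapping that concrete quantizers override:

class HfQuantizerSketch:
    # Hedged sketch of the assumed base-class default: return the name unchanged
    # unless a concrete quantizer overrides this method.
    def get_param_name(self, param_name: str) -> str:
        """
        Override this method if you want to adjust the `param_name`.
        """
        return param_name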
6 changes: 3 additions & 3 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -154,9 +154,9 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
module, name = get_module_from_name(model, param_name)
return isinstance(module, bnb.nn.Linear4bit) and name != "bias"

- def update_param_name(self, param_name: str) -> str:
+ def get_param_name(self, param_name: str) -> str:
"""
- Update param_name in order to get the module associated with the param.
+ Get the right param_name in order to get the module associated with the param.
This is useful for quantized stats like absmax or quant_map, as we need to update the param_name to get the module since they are stored in ...weight.absmax.
"""
if self.pre_quantized:
@@ -180,7 +180,7 @@ def create_quantized_param(
full_name = param_name

# update param name to get the weights instead of the quantized stats
- param_name = self.update_param_name(param_name)
+ param_name = self.get_param_name(param_name)
module, tensor_name = get_module_from_name(model, param_name)

# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
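For bitsandbytes 4-bit checkpoints, quantization statistics such as absmax and quant_map are serialized under keys like model.layers.0.mlp.weight.absmax, while the owning Linear4bit module only exposes weight. The method body is collapsed above, so the following is a hypothetical illustration of the mapping, not the library's implementation:

def get_param_name(param_name: str, pre_quantized: bool = True) -> str:
    # Illustration only: map "...weight.absmax" / "...weight.quant_map" back to
    # "...weight" so get_module_from_name resolves the module that owns the stats.
    if pre_quantized and ".weight." in param_name:
        return param_name.split(".weight.")[0] + ".weight"
    return param_name

print(get_param_name("model.layers.0.mlp.weight.absmax"))  # model.layers.0.mlp.weight
print(get_param_name("model.layers.0.mlp.weight"))         # unchanged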
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_mxfp4.py
@@ -365,7 +365,7 @@ def update_ep_plan(self, config):
)
return config

- def update_param_name(self, param_name: str) -> str:
+ def get_param_name(self, param_name: str) -> str:
if self.quantization_config.dequantize:
if "_blocks" in param_name:
return param_name.replace("_blocks", "")
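Only the _blocks branch is visible above; the tests below additionally exercise a _scales branch and the pass-through cases, so the full mapping presumably looks roughly like this sketch (inferred from the visible branch and the tests, not the exact method body):

def get_param_name(param_name: str, dequantize: bool = True) -> str:
    # When dequantizing, packed "_blocks"/"_scales" checkpoint keys map back to
    # the original parameter name; otherwise names pass through untouched.
    if dequantize:
        if "_blocks" in param_name:
            return param_name.replace("_blocks", "")
        if "_scales" in param_name:
            return param_name.replace("_scales", "")
    return param_name

assert get_param_name("model.layers.0.mlp.experts.gate_up_proj_blocks") == "model.layers.0.mlp.experts.gate_up_proj"
assert get_param_name("model.layers.0.mlp.experts.down_proj_scales") == "model.layers.0.mlp.experts.down_proj"
assert get_param_name("model.embed_tokens.weight") == "model.embed_tokens.weight"
assert get_param_name("model.layers.0.mlp.experts.gate_up_proj_blocks", dequantize=False) == "model.layers.0.mlp.experts.gate_up_proj_blocks"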
12 changes: 6 additions & 6 deletions tests/quantization/mxfp4/test_mxfp4.py
@@ -265,7 +265,7 @@ def test_update_expected_keys(self):

self.assertEqual(set(updated_keys), set(expected_updated))

- def test_update_param_name_dequantize(self):
+ def test_get_param_name_dequantize(self):
"""Test parameter name updating when dequantizing"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

@@ -274,28 +274,28 @@ def test_update_param_name_dequantize(self):

# Should remove _blocks suffix
param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
- updated_name = quantizer.update_param_name(param_name)
+ updated_name = quantizer.get_param_name(param_name)
self.assertEqual(updated_name, "model.layers.0.mlp.experts.gate_up_proj")

# Should remove _scales suffix
param_name = "model.layers.0.mlp.experts.down_proj_scales"
- updated_name = quantizer.update_param_name(param_name)
+ updated_name = quantizer.get_param_name(param_name)
self.assertEqual(updated_name, "model.layers.0.mlp.experts.down_proj")

# Should not change other names
param_name = "model.embed_tokens.weight"
- updated_name = quantizer.update_param_name(param_name)
+ updated_name = quantizer.get_param_name(param_name)
self.assertEqual(updated_name, "model.embed_tokens.weight")

- def test_update_param_name_no_dequantize(self):
+ def test_get_param_name_no_dequantize(self):
"""Test parameter name updating when not dequantizing"""
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

config = Mxfp4Config(dequantize=False)
quantizer = Mxfp4HfQuantizer(config)

param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
- updated_name = quantizer.update_param_name(param_name)
+ updated_name = quantizer.get_param_name(param_name)
self.assertEqual(updated_name, param_name)

def test_is_trainable(self):