Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix pipeline dtype
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
  • Loading branch information
jiqing-feng committed Sep 3, 2025
commit 3321f2afd3adcdbad0b8414cd233f945b4dfe9d7
44 changes: 22 additions & 22 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,29 +995,29 @@ def pipeline(
)
model_kwargs["device_map"] = device_map

# BC for the `torch_dtype` argument
if (torch_dtype := kwargs.get("torch_dtype")) is not None:
# BC for the `torch_dtype` argument
if (torch_dtype := kwargs.get("torch_dtype")) is not None:
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
# If both are provided, keep `dtype`
dtype = torch_dtype if dtype == "auto" else dtype
if "torch_dtype" in model_kwargs or "dtype" in model_kwargs:
if "torch_dtype" in model_kwargs:
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
# If both are provided, keep `dtype`
dtype = torch_dtype if dtype == "auto" else dtype
if "torch_dtype" in model_kwargs or "dtype" in model_kwargs:
if "torch_dtype" in model_kwargs:
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
# If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
# raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
# top-level argument keeps its default value "auto".
if dtype == "auto":
dtype = None
else:
raise ValueError(
'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if dtype is not None:
if isinstance(dtype, str) and hasattr(torch, dtype):
dtype = getattr(torch, dtype)
model_kwargs["dtype"] = dtype
# If the user did not explicitly provide `dtype` (i.e. the function default "auto" is still
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
# raising. This prevents false positives like providing `dtype` only via `model_kwargs` while the
# top-level argument keeps its default value "auto".
if dtype == "auto":
dtype = None
else:
raise ValueError(
'You cannot use both `pipeline(... dtype=..., model_kwargs={"dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if dtype is not None:
if isinstance(dtype, str) and hasattr(torch, dtype):
dtype = getattr(torch, dtype)
model_kwargs["dtype"] = dtype

model_name = model if isinstance(model, str) else None

Expand Down