Fix bias dtype issue in mixed ops. (#11293)
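
Previously, Linear.__init__ inside mixed_precision_ops allocated the bias eagerly via **self.factory_kwargs, so the bias could end up in the module's construction dtype even when the weight was later loaded with a different per-layer dtype. The fix records only whether a bias exists (self._has_bias) and defers the allocation to _load_from_state_dict, where the actual weight dtype (or the compute dtype, for quantized layers) is known.

A minimal standalone sketch of the pattern, assuming a hypothetical class name LazyBiasLinear (simplified from the real nested class in mixed_precision_ops):

    import torch

    class LazyBiasLinear(torch.nn.Module):
        def __init__(self, in_features, out_features, bias=True):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            # Only remember whether a bias exists; do not allocate it yet,
            # since the checkpoint dtype is unknown at construction time.
            self._has_bias = bias

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            weight = state_dict.pop(prefix + "weight")
            self.weight = torch.nn.Parameter(weight, requires_grad=False)
            if self._has_bias:
                # Allocate the bias here, matching the loaded weight's dtype.
                self.bias = torch.nn.Parameter(
                    torch.empty(self.out_features, device=weight.device, dtype=weight.dtype),
                    requires_grad=False,
                )
                bias = state_dict.pop(prefix + "bias", None)
                if bias is not None:
                    with torch.no_grad():
                        self.bias.copy_(bias)
            else:
                self.register_parameter("bias", None)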

comfyanonymous 2025-12-12 08:49:35 -08:00 committed by GitHub
parent 908fd7d749
commit c5a47a1692

@@ -504,10 +504,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 self.in_features = in_features
                 self.out_features = out_features
-                if bias:
-                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
-                else:
-                    self.register_parameter("bias", None)
+                self._has_bias = bias
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@@ -536,6 +533,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
                     if dtype != MixedPrecisionOps._compute_dtype:
                         self.comfy_cast_weights = True
+                    if self._has_bias:
+                        self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
+                    else:
+                        self.register_parameter("bias", None)
                 else:
                     self.quant_format = layer_conf.get("format", None)
                     if not self._full_precision_mm:
@@ -565,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                         requires_grad=False
                     )
+                    if self._has_bias:
+                        self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
+                    else:
+                        self.register_parameter("bias", None)
                     for param_name in qconfig["parameters"]:
                         param_key = f"{prefix}{param_name}"
                         _v = state_dict.pop(param_key, None)
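
Exercising the hypothetical LazyBiasLinear sketch from the commit message above shows the intended behavior: the bias dtype now follows whatever dtype the checkpoint weight arrives in, rather than the construction-time factory kwargs.

    # Hypothetical check: bias dtype tracks the loaded weight dtype.
    layer = LazyBiasLinear(4, 8)
    sd = {
        "weight": torch.randn(8, 4, dtype=torch.float16),
        "bias": torch.zeros(8, dtype=torch.float16),
    }
    layer.load_state_dict(sd)
    assert layer.bias.dtype == layer.weight.dtype == torch.float16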