Respect the dtype the op was initialized in for non quant mixed op. (#11282)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
comfyanonymous 2025-12-11 20:32:27 -08:00 committed by GitHub
parent 982876d59a
commit 5495589db3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -497,8 +497,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
@ -530,7 +532,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm: