Properly save mixed ops. (#11772)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
comfyanonymous 2026-01-09 23:03:57 -08:00 committed by GitHub
parent 153bc524bf
commit dc202a2e51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 12 deletions

View File

@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
layout_cls = self.weight._layout_cls
if destination is not None:
sd = destination
else:
sd = {}
# Check if it's any FP8 variant (E4M3 or E5M2)
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout":
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias):

View File

@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)