Properly save mixed ops. (#11772)

comfyanonymous committed 2026-01-09 23:03:57 -08:00 (via GitHub)
parent 153bc524bf
commit dc202a2e51
2 changed files with 20 additions and 12 deletions


@@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 missing_keys.remove(key)
 
         def state_dict(self, *args, destination=None, prefix="", **kwargs):
-            sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
-            if isinstance(self.weight, QuantizedTensor):
-                layout_cls = self.weight._layout_cls
-
-                # Check if it's any FP8 variant (E4M3 or E5M2)
-                if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
-                    sd["{}weight_scale".format(prefix)] = self.weight._params.scale
-                elif layout_cls == "TensorCoreNVFP4Layout":
-                    sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
-                    sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+            if destination is not None:
+                sd = destination
+            else:
+                sd = {}
+            if self.bias is not None:
+                sd["{}bias".format(prefix)] = self.bias
+
+            if isinstance(self.weight, QuantizedTensor):
+                sd_out = self.weight.state_dict("{}weight".format(prefix))
+                for k in sd_out:
+                    sd[k] = sd_out[k]
 
                 quant_conf = {"format": self.quant_format}
                 if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+                input_scale = getattr(self, 'input_scale', None)
+                if input_scale is not None:
+                    sd["{}input_scale".format(prefix)] = input_scale
+            else:
+                sd["{}weight".format(prefix)] = self.weight
             return sd
 
         def _forward(self, input, weight, bias):
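
Note: the "{}comfy_quant" entry written above is just the quantization config serialized as JSON inside a uint8 tensor. A minimal round-trip sketch of that encoding, separate from the repository's own loading code (the helper names and the example format string are illustrative, not from this commit):

import json
import torch

def encode_quant_conf(quant_conf):
    # Same encoding as the state_dict() above: dict -> JSON -> UTF-8 bytes -> uint8 tensor.
    return torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)

def decode_quant_conf(t):
    # Reverse direction: uint8 tensor -> bytes -> JSON -> dict.
    return json.loads(bytes(t.tolist()).decode('utf-8'))

conf = {"format": "fp8_e4m3", "full_precision_matrix_mult": True}  # format value is a placeholder
assert decode_quant_conf(encode_quant_conf(conf)) == conf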


@@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
         state_dict2 = model.state_dict()
 
         # Verify layer1.weight is a QuantizedTensor with scale preserved
-        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
-        self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
+        self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
+        self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
+        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
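
For reference, a rough sketch of the key layout the updated state_dict() is expected to produce for a model shaped like the one in this test (layer1 FP8-quantized, layer2 a plain linear layer); the layer names and model setup are assumptions, only the key suffixes come from the diff above:

sd = model.state_dict()

# Quantized layer: raw payload plus the metadata entries added by the new state_dict().
expected_layer1_keys = {
    "layer1.weight",        # quantized payload; comparable via .view(torch.uint8)
    "layer1.weight_scale",  # per-tensor scale (3.0 in the test above)
    "layer1.comfy_quant",   # uint8-encoded JSON config, see the encode/decode sketch above
}
assert expected_layer1_keys.issubset(sd.keys())

# Non-quantized layer: stored as a regular tensor under the usual key.
assert "layer2.weight" in sd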