diff --git a/comfy/ops.py b/comfy/ops.py
index 1cf22f0cc..9c0b54ff4 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -625,21 +625,41 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                         missing_keys.remove(key)
 
             def state_dict(self, *args, destination=None, prefix="", **kwargs):
-                sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
-                if isinstance(self.weight, QuantizedTensor):
-                    layout_cls = self.weight._layout_cls
+                """Collect this layer's tensors into a state dict.
+
+                Entries are written into ``destination`` when one is supplied,
+                otherwise into a new plain dict; ``*args`` and ``**kwargs`` are
+                accepted but not consulted. The bias is stored when present. A
+                ``QuantizedTensor`` weight is serialized by its own ``state_dict``
+                (called with ``<prefix>weight``) and every entry it returns is
+                copied in, followed by ``<prefix>comfy_quant`` (the JSON
+                quantization config as UTF-8 bytes in a uint8 tensor) and, if the
+                layer has a non-None ``input_scale``, ``<prefix>input_scale``. Any
+                other weight is stored unchanged under ``<prefix>weight``.
+                """
+                if destination is not None:
+                    sd = destination
+                else:
+                    sd = {}
 
-                    # Check if it's any FP8 variant (E4M3 or E5M2)
-                    if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
-                        sd["{}weight_scale".format(prefix)] = self.weight._params.scale
-                    elif layout_cls == "TensorCoreNVFP4Layout":
-                        sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
-                        sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
+                if self.bias is not None:
+                    sd["{}bias".format(prefix)] = self.bias
+
+                if isinstance(self.weight, QuantizedTensor):
+                    sd_out = self.weight.state_dict("{}weight".format(prefix))
+                    for k in sd_out:
+                        sd[k] = sd_out[k]
 
                     quant_conf = {"format": self.quant_format}
                     if self._full_precision_mm_config:
                         quant_conf["full_precision_matrix_mult"] = True
                     sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+
+                    input_scale = getattr(self, 'input_scale', None)
+                    if input_scale is not None:
+                        sd["{}input_scale".format(prefix)] = input_scale
+                else:
+                    sd["{}weight".format(prefix)] = self.weight
                 return sd
 
             def _forward(self, input, weight, bias):
diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py
index 7b2eac940..7c740491d 100644
--- a/tests-unit/comfy_quant/test_mixed_precision.py
+++ b/tests-unit/comfy_quant/test_mixed_precision.py
@@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
         state_dict2 = model.state_dict()
 
-        # Verify layer1.weight is a QuantizedTensor with scale preserved
-        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
-        self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
+        # Verify layer1.weight round-trips byte-for-byte and its scale is saved as layer1.weight_scale
+        self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
+        self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
+        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)