mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 04:37:32 +08:00
Compare commits
No commits in common. "dc202a2e51bf7a6cd00e606b2d2941bc223f2ad2" and "393d2880ddc6e57c0fa3b878bb50fa2901bd57e6" have entirely different histories.
dc202a2e51
...
393d2880dd
26
comfy/ops.py
26
comfy/ops.py
@ -625,29 +625,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
missing_keys.remove(key)
|
missing_keys.remove(key)
|
||||||
|
|
||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
if destination is not None:
|
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||||
sd = destination
|
|
||||||
else:
|
|
||||||
sd = {}
|
|
||||||
|
|
||||||
if self.bias is not None:
|
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
|
||||||
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
layout_cls = self.weight._layout_cls
|
||||||
for k in sd_out:
|
|
||||||
sd[k] = sd_out[k]
|
# Check if it's any FP8 variant (E4M3 or E5M2)
|
||||||
|
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
|
||||||
|
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
|
||||||
|
elif layout_cls == "TensorCoreNVFP4Layout":
|
||||||
|
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
|
||||||
|
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
|
||||||
|
|
||||||
quant_conf = {"format": self.quant_format}
|
quant_conf = {"format": self.quant_format}
|
||||||
if self._full_precision_mm_config:
|
if self._full_precision_mm_config:
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||||
|
|
||||||
input_scale = getattr(self, 'input_scale', None)
|
|
||||||
if input_scale is not None:
|
|
||||||
sd["{}input_scale".format(prefix)] = input_scale
|
|
||||||
else:
|
|
||||||
sd["{}weight".format(prefix)] = self.weight
|
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
comfyui-frontend-package==1.36.13
|
comfyui-frontend-package==1.36.13
|
||||||
comfyui-workflow-templates==0.7.69
|
comfyui-workflow-templates==0.7.69
|
||||||
comfyui-embedded-docs==0.4.0
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
|||||||
@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
state_dict2 = model.state_dict()
|
state_dict2 = model.state_dict()
|
||||||
|
|
||||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||||
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
|
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||||
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
|
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
|
||||||
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
|
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
|
||||||
|
|
||||||
# Verify non-quantized layers are standard tensors
|
# Verify non-quantized layers are standard tensors
|
||||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user