Compare commits

...

11 Commits

Author SHA1 Message Date
Macpaul Lin
c38ec4de6d
Merge 38f5db0118 into dc202a2e51 2026-01-10 09:29:25 +02:00
comfyanonymous
dc202a2e51
Properly save mixed ops. (#11772) 2026-01-10 02:03:57 -05:00
ComfyUI Wiki
153bc524bf
chore: update embedded docs to v0.4.0 (#11776) 2026-01-10 01:29:30 -05:00
Macpaul Lin
38f5db0118 fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ea3ec049bd fix(quant_ops): implement __torch_function__ to support torch.empty_like for mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
96803b16c0 fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
9907a5e4f5 fix(quant_ops): add numel, size, shape, dim, and ndim to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
e3cc20034d fix(quant_ops): add _layout_cls and _params aliases to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
77a46c68ea fix(quant_ops): add detach, clone, and requires_grad_ to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
406dab2d53 fix(quant_ops): improve comfy_kitchen fallback logic to prevent loading errors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ef7b4a717a feat(mps): implement native-like Float8 support via LUT dequantization
Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
Lookup Table (LUT) for efficient dequantization, keeping data on the GPU.

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8
  when moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to
  prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: Add fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
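A minimal standalone sketch of the byte-view plus LUT idea described in this commit message (assuming a recent PyTorch build with float8 dtypes; it runs on CPU here, with MPS as the intended target):

import torch

# Quantized weight stored as raw bytes: the uint8 view preserves the float8
# bit pattern, so the data can live on a device with no native float8 support.
w8 = torch.randn(4, 4).to(torch.float8_e4m3fn)
codes = w8.view(torch.uint8)

# 256-entry lookup table: every possible byte value, reinterpreted as float8
# and cast to float16. Indexing the table with the byte codes dequantizes.
lut = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.float16)
decoded = lut[codes.long()]

assert torch.equal(decoded, w8.to(torch.float16))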
6 changed files with 265 additions and 22 deletions

View File

@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
# MPS workaround: perform float8 conversion on CPU
target_device = value.device
use_cpu_staging = (target_device.type == "mps")
output_device = "cpu" if use_cpu_staging else target_device
output = torch.empty_like(value, dtype=dtype, device=output_device)
generator = torch.Generator(device=target_device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
if use_cpu_staging:
res = res.cpu()
output[i:i+slice_size].copy_(res)
if use_cpu_staging:
return output.to(target_device)
return output
return value.to(dtype=dtype)
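The hunk above stages float8 results on CPU when the source tensor lives on MPS; a simplified sketch of that pattern, using a hypothetical helper name that is not part of the patch:

import torch

def cast_with_cpu_staging(value, dtype=torch.float8_e4m3fn):
    # Mirror of the staging idea: materialize the float8 output on CPU when
    # the input is on MPS, then move the finished tensor back to the device.
    target_device = value.device
    use_cpu_staging = target_device.type == "mps"
    staging_device = "cpu" if use_cpu_staging else target_device
    output = torch.empty_like(value, dtype=dtype, device=staging_device)
    output.copy_(value.to(device=staging_device).to(dtype))
    return output.to(target_device) if use_cpu_staging else output

x = torch.randn(8, 8)
print(cast_with_cpu_staging(x).dtype)  # torch.float8_e4m3fn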

comfy/mps_ops.py Normal file (77 lines)
View File

@@ -0,0 +1,77 @@
import torch
_LUT_CACHE = {}
def get_lut(dtype, device):
"""
Get or create a lookup table for float8 dequantization on MPS.
Returns a Tensor[256] of dtype=torch.float16 on the specified device.
"""
key = (dtype, device)
if key in _LUT_CACHE:
return _LUT_CACHE[key]
# Build the decoding table:
# 1. Create the uint8 pattern 0..255 on CPU (every possible byte value).
# 2. View it as the target float8 dtype (same element size, and the
#    reinterpreting view is supported on CPU).
# 3. Cast to float16 to get the decoded values.
# 4. Move the resulting float16 LUT to the requested (MPS) device.
byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
try:
f8_tensor = byte_pattern.view(dtype)
f16_lut = f8_tensor.to(torch.float16)
# Move to the requested MPS device
lut = f16_lut.to(device)
_LUT_CACHE[key] = lut
return lut
except Exception as e:
print(f"Failed to create MPS LUT for {dtype}: {e}")
# No usable fallback: re-raise so the caller can handle the failure
raise
def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
"""
Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
Args:
qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
scale: Tensor (scalar)
orig_dtype: The target dtype (e.g. float16)
float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
Returns:
Tensor of shape (...) with dtype=orig_dtype
"""
lut = get_lut(float8_dtype, qdata.device)
# Use index_select or advanced indexing.
# Advanced indexing lut[qdata.long()] is generally efficient.
# We explicitly cast to long (int64) for indexing.
# Note: Flattening might be slightly faster depending on shape, but simple indexing is safest.
# We want the LUT to be in the target orig_dtype (likely float16 or bfloat16)
if lut.dtype != orig_dtype:
lut = lut.to(dtype=orig_dtype)
output = lut[qdata.long()]
# Apply scale
# Scale might need to be cast to orig_dtype too
if isinstance(scale, torch.Tensor):
scale = scale.to(dtype=orig_dtype)
output.mul_(scale)
return output
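Assuming the module above is importable as comfy.mps_ops, a usage sketch (CPU stands in for MPS here, since the LUT path only needs uint8 indexing):

import torch
import comfy.mps_ops as mps_ops  # the new module shown above

w = torch.randn(16, 32, dtype=torch.float16)
scale = torch.tensor(0.05)

q = (w / scale).to(torch.float8_e4m3fn)   # per-tensor FP8 quantization
codes = q.view(torch.uint8)               # byte storage used on MPS

deq = mps_ops.mps_dequantize(codes, scale, torch.float16, torch.float8_e4m3fn)
ref = q.to(torch.float16) * scale.to(torch.float16)
assert torch.allclose(deq, ref)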

View File

@@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
layout_cls = self.weight._layout_cls
if destination is not None:
sd = destination
else:
sd = {}
# Check if it's any FP8 variant (E4M3 or E5M2)
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout":
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias):
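For reference, the comfy_quant entry written by this state_dict is plain JSON packed into a uint8 tensor, so a loader can recover the config without custom serialization. A small round-trip sketch (the format value is an illustrative placeholder, not necessarily the exact string ComfyUI uses):

import json
import torch

quant_conf = {"format": "float8_e4m3fn",          # placeholder format name
              "full_precision_matrix_mult": True}
packed = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)

recovered = json.loads(bytes(packed.tolist()).decode("utf-8"))
assert recovered == quant_conf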

View File

@@ -28,22 +28,148 @@ except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
class ck_dummy:
@staticmethod
def quantize_per_tensor_fp8(tensor, scale, dtype):
return (tensor / scale.to(tensor.device)).to(dtype)
ck = ck_dummy
class QuantizedTensor:
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_cls = layout_type # Alias for compatibility
self._layout_params = layout_params
self._params = layout_params # Alias for compatibility
self.device = qdata.device
self.dtype = qdata.dtype
@classmethod
def from_float(cls, tensor, layout_type, **kwargs):
layout_cls = get_layout_class(layout_type)
if layout_cls is None:
raise ValueError(f"Unknown layout type: {layout_type}")
qdata, params = layout_cls.quantize(tensor, **kwargs)
return cls(qdata, layout_type, params)
def dequantize(self):
layout_cls = get_layout_class(self._layout_type)
if layout_cls is None:
return self._qdata
return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
def to(self, *args, **kwargs):
device = kwargs.get("device", None)
dtype = kwargs.get("dtype", None)
if len(args) > 0:
if isinstance(args[0], (torch.device, str)):
device = args[0]
elif isinstance(args[0], torch.dtype):
dtype = args[0]
new_qdata = self._qdata.to(*args, **kwargs)
new_params = self._layout_params.copy()
if device is not None:
for k, v in new_params.__dict__.items():
if isinstance(v, torch.Tensor):
new_params.__dict__[k] = v.to(device=device)
if dtype is not None:
new_params.orig_dtype = dtype
return type(self)(new_qdata, self._layout_type, new_params)
def detach(self):
return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy())
def clone(self):
return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy())
def requires_grad_(self, requires_grad=True):
self._qdata.requires_grad_(requires_grad)
return self
def numel(self):
if hasattr(self._layout_params, "orig_shape"):
import math
return math.prod(self._layout_params.orig_shape)
return self._qdata.numel()
@property
def shape(self):
if hasattr(self._layout_params, "orig_shape"):
return torch.Size(self._layout_params.orig_shape)
return self._qdata.shape
@property
def ndim(self):
return len(self.shape)
def size(self, dim=None):
if dim is None:
return self.shape
return self.shape[dim]
def dim(self):
return self.ndim
def __getattr__(self, name):
if name == "params":
return self._layout_params
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.empty_like:
input_t = args[0]
if isinstance(input_t, cls):
dtype = kwargs.get("dtype", input_t.dtype)
device = kwargs.get("device", input_t.device)
return torch.empty(input_t.shape, dtype=dtype, device=device)
if func is torch.Tensor.copy_:
dst, src = args[:2]
if isinstance(src, cls):
return dst.copy_(src.dequantize(), **kwargs)
return NotImplemented
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return NotImplemented
class QuantizedLayout:
class Params:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def copy(self):
return type(self)(**self.__dict__)
class _CKFp8Layout(QuantizedLayout):
pass
class _CKFp8Layout:
class TensorCoreNVFP4Layout(QuantizedLayout):
pass
class TensorCoreNVFP4Layout:
pass
_LOCAL_LAYOUT_REGISTRY = {}
def register_layout_class(name, cls):
pass
_LOCAL_LAYOUT_REGISTRY[name] = cls
def get_layout_class(name):
return None
return _LOCAL_LAYOUT_REGISTRY.get(name)
def register_layout_op(torch_op, layout_type):
def decorator(handler_func):
return handler_func
return decorator
import comfy.float
import comfy.mps_ops
# ==============================================================================
# FP8 Layouts with Comfy-Specific Extensions
@@ -51,7 +177,13 @@ import comfy.float
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
FP8_DTYPE = None # Must be overridden in subclass
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if cls.FP8_DTYPE is None:
@@ -83,6 +215,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
if qdata.device.type == "mps":
if qdata.dtype == torch.uint8:
return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
elif qdata.is_floating_point() and qdata.element_size() == 1:
# It is MPS Float8. View as uint8.
return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn
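The fallback QuantizedTensor above is not a torch.Tensor subclass, so it relies on the __torch_function__ protocol to answer calls such as torch.empty_like. A minimal standalone illustration of that mechanism (a toy class, not the ComfyUI one):

import torch

class ShapeOnlyProxy:
    # Toy stand-in that carries only shape/dtype/device metadata, like the
    # mock QuantizedTensor, and intercepts torch.empty_like.
    def __init__(self, shape, dtype, device="cpu"):
        self.shape = torch.Size(shape)
        self.dtype = dtype
        self.device = torch.device(device)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.empty_like and isinstance(args[0], cls):
            src = args[0]
            return torch.empty(src.shape,
                               dtype=kwargs.get("dtype", src.dtype),
                               device=kwargs.get("device", src.device))
        return NotImplemented

proxy = ShapeOnlyProxy((4, 8), torch.float16)
out = torch.empty_like(proxy)
print(out.shape, out.dtype)  # torch.Size([4, 8]) torch.float16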

View File

@@ -1,6 +1,6 @@
comfyui-frontend-package==1.36.13
comfyui-workflow-templates==0.7.69
comfyui-embedded-docs==0.3.1
comfyui-embedded-docs==0.4.0
torch
torchsde
torchvision

View File

@@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)