From ef7b4a717a891c7a9b5f343f8cc851eb46192d01 Mon Sep 17 00:00:00 2001
From: Macpaul Lin
Date: Wed, 7 Jan 2026 01:31:21 +0800
Subject: [PATCH 1/8] feat(mps): implement native-like Float8 support via LUT
 dequantization

Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation stores Float8 data as uint8 and dequantizes it through a
GPU-accelerated lookup table (LUT), keeping the data on the GPU throughout.

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8 when
  moving to MPS, route dequantization to mps_ops, and add a fallback for
  fp8_linear.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to prevent
  MPS casting errors during quantization.

Signed-off-by: Macpaul Lin
---
 comfy/float.py     | 19 ++++++++++--
 comfy/mps_ops.py   | 77 ++++++++++++++++++++++++++++++++++++++++++++++
 comfy/quant_ops.py | 22 ++++++++++++-
 3 files changed, 114 insertions(+), 4 deletions(-)
 create mode 100644 comfy/mps_ops.py
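
Note: the comfy/float.py hunk below stages the float8 conversion on the CPU
because MPS cannot cast to float8 directly. A minimal sketch of that staging
pattern, assuming only torch; the function name and arguments are
illustrative, not part of the patch, and it presumes a PyTorch build where
float8 tensors can be stored on MPS even though they cannot be computed there:

    import torch

    def cast_fp8_via_cpu(value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # MPS kernels cannot cast to float8, so convert on the CPU and move
        # the resulting float8 tensor back to the original device.
        if value.device.type == "mps":
            return value.cpu().to(dtype).to(value.device)
        return value.to(dtype)

    # e.g. w8 = cast_fp8_via_cpu(w, torch.float8_e4m3fn)
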
diff --git a/comfy/float.py b/comfy/float.py
index 521316fd2..848c6ff68 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: perform the float8 conversion on the CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
diff --git a/comfy/mps_ops.py b/comfy/mps_ops.py
new file mode 100644
index 000000000..a5930356b
--- /dev/null
+++ b/comfy/mps_ops.py
@@ -0,0 +1,77 @@
+import torch
+
+_LUT_CACHE = {}
+
+def get_lut(dtype, device):
+    """
+    Get or create a lookup table for float8 dequantization on MPS.
+    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
+    """
+    key = (dtype, device)
+    if key in _LUT_CACHE:
+        return _LUT_CACHE[key]
+
+    # Generate all possible 8-bit values (0-255).
+    # We create them on CPU first as float8, then cast to float16, then move
+    # to MPS. The result acts as our decoding table.
+
+    # Create uint8 pattern 0..255
+    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
+
+    # View as the target float8 type.
+    # Note: .view() requires source and target dtypes to have the same element
+    # size; uint8 and float8 are both one byte, so reinterpreting the byte
+    # pattern as float8 is valid (on CPU, where float8 views are supported).
+
+    # LUT generation, step by step:
+    # 1. Create bytes 0..255
+    # 2. View as float8 (on CPU, where it is supported)
+    # 3. Convert to float16 (on CPU)
+    # 4. Move the float16 LUT to MPS
+
+    try:
+        f8_tensor = byte_pattern.view(dtype)
+        f16_lut = f8_tensor.to(torch.float16)
+
+        # Move to the requested MPS device
+        lut = f16_lut.to(device)
+        _LUT_CACHE[key] = lut
+        return lut
+    except Exception as e:
+        print(f"Failed to create MPS LUT for {dtype}: {e}")
+        # Re-raise: callers cannot proceed without a valid LUT.
+        raise e
+
+def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
+    """
+    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
+
+    Args:
+        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
+        scale: Tensor (scalar)
+        orig_dtype: The target dtype (e.g. float16)
+        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
+
+    Returns:
+        Tensor of shape (...) with dtype=orig_dtype
+    """
+    lut = get_lut(float8_dtype, qdata.device)
+
+    # Dequantize via advanced indexing: lut[qdata.long()] gathers one LUT entry
+    # per byte; the indices must be cast to long (int64) first.
+    # Flattening qdata might be marginally faster for some shapes, but plain
+    # indexing is the safest choice.
+
+    # Cast the LUT to the target orig_dtype (typically float16 or bfloat16)
+    if lut.dtype != orig_dtype:
+        lut = lut.to(dtype=orig_dtype)
+
+    output = lut[qdata.long()]
+
+    # Apply the scale; cast it to orig_dtype so the in-place multiply
+    # matches the output dtype.
+    if isinstance(scale, torch.Tensor):
+        scale = scale.to(dtype=orig_dtype)
+
+    output.mul_(scale)
+    return output
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 8324be42a..08a8a996d 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -44,6 +44,7 @@ except ImportError as e:
         return None
 
 import comfy.float
+import comfy.mps_ops
 
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
@@ -51,7 +52,13 @@ import comfy.float
 
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
-
+
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        if cls.FP8_DTYPE is None:
@@ -83,6 +90,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # Float8 tensor on MPS: reinterpret its storage as uint8 for the LUT path.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        plain_tensor.mul_(scale)
+        return plain_tensor
+
 
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
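
Note: the core of the comfy/mps_ops.py approach above is a single 256-entry
gather. A self-contained sketch of the LUT dequantization, runnable on CPU
for illustration; the shapes and the 0.05 scale are made up, and it assumes
a PyTorch version that supports viewing uint8 storage as float8 (the same
assumption the patch itself makes):

    import torch

    # Decoding table: every byte pattern 0..255, reinterpreted as
    # float8_e4m3fn and widened to float16.
    lut = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.float16)
    # (Bytes 0x7f and 0xff decode to NaN in e4m3fn; harmless for this demo.)

    # Quantized weights stored as raw bytes, e.g. after .view(torch.uint8).
    qdata = torch.randint(0, 256, (4, 4), dtype=torch.uint8)
    scale = torch.tensor(0.05, dtype=torch.float16)

    # Dequantize: one gather through the table, then apply the per-tensor scale.
    dequant = lut[qdata.long()] * scale
    print(dequant.shape, dequant.dtype)  # torch.Size([4, 4]) torch.float16
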
From 406dab2d535f0a6e462b537145a18973344a9587 Mon Sep 17 00:00:00 2001
From: Macpaul Lin
Date: Thu, 8 Jan 2026 22:46:17 +0800
Subject: [PATCH 2/8] fix(quant_ops): improve comfy_kitchen fallback logic to
 prevent loading errors

Signed-off-by: Macpaul Lin
---
 comfy/quant_ops.py | 75 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 70 insertions(+), 5 deletions(-)

diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 08a8a996d..545dffb30 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -28,20 +28,85 @@ except ImportError as e:
     logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
 
+    class ck_dummy:  # stand-in used when comfy_kitchen is unavailable
+        @staticmethod
+        def quantize_per_tensor_fp8(tensor, scale, dtype):
+            return (tensor / scale.to(tensor.device)).to(dtype)
+    ck = ck_dummy
+
+    class QuantizedTensor:  # pure-Python fallback mirroring the comfy_kitchen API
+        def __init__(self, qdata, layout_type, layout_params):
+            self._qdata = qdata
+            self._layout_type = layout_type
+            self._layout_params = layout_params
+            self.device = qdata.device
+            self.dtype = qdata.dtype
+
+        @classmethod
+        def from_float(cls, tensor, layout_type, **kwargs):
+            layout_cls = get_layout_class(layout_type)
+            if layout_cls is None:
+                raise ValueError(f"Unknown layout type: {layout_type}")
+            qdata, params = layout_cls.quantize(tensor, **kwargs)
+            return cls(qdata, layout_type, params)
+
+        def dequantize(self):
+            layout_cls = get_layout_class(self._layout_type)
+            if layout_cls is None:
+                return self._qdata
+            return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
+
+        def to(self, *args, **kwargs):
+            device = kwargs.get("device", None)
+            if device is None and len(args) > 0:
+                if isinstance(args[0], (torch.device, str)):
+                    device = args[0]
+
+            new_qdata = self._qdata.to(*args, **kwargs)
+            new_params = self._layout_params.copy()
+            if device is not None:
+                for k, v in new_params.__dict__.items():
+                    if isinstance(v, torch.Tensor):
+                        new_params.__dict__[k] = v.to(device=device)
+
+            return type(self)(new_qdata, self._layout_type, new_params)
+
+        def __getattr__(self, name):
+            if name == "params":
+                return self._layout_params
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            return NotImplemented
+
+    class QuantizedLayout:
+        class Params:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+            def copy(self):
+                return type(self)(**self.__dict__)
+
+    class _CKFp8Layout(QuantizedLayout):
         pass
 
-    class _CKFp8Layout:
+    class TensorCoreNVFP4Layout(QuantizedLayout):
         pass
 
-    class TensorCoreNVFP4Layout:
-        pass
+    _LOCAL_LAYOUT_REGISTRY = {}  # layout name -> layout class
 
     def register_layout_class(name, cls):
-        pass
+        _LOCAL_LAYOUT_REGISTRY[name] = cls
 
     def get_layout_class(name):
-        return None
+        return _LOCAL_LAYOUT_REGISTRY.get(name)
+
+    def register_layout_op(torch_op, layout_type):
+        def decorator(handler_func):
+            return handler_func
+        return decorator
+
 
 import comfy.float
 import comfy.mps_ops

From 77a46c68ea331fd36e98179c7b261d5d5cfa3a14 Mon Sep 17 00:00:00 2001
From: Macpaul Lin
Date: Thu, 8 Jan 2026 22:47:42 +0800
Subject: [PATCH 3/8] fix(quant_ops): add detach, clone, and requires_grad_ to
 mock QuantizedTensor

Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 545dffb30..654f5870b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -71,6 +71,16 @@ except ImportError as e: return type(self)(new_qdata, self._layout_type, new_params) + def detach(self): + return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy()) + + def clone(self): + return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy()) + + def requires_grad_(self, requires_grad=True): + self._qdata.requires_grad_(requires_grad) + return self + def __getattr__(self, name): if name == "params": return self._layout_params From e3cc20034decf696d291ffd080d860707a131906 Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:50:42 +0800 Subject: [PATCH 4/8] fix(quant_ops): add _layout_cls and _params aliases to mock QuantizedTensor Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 654f5870b..1e0cc0304 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -38,7 +38,9 @@ except ImportError as e: def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata self._layout_type = layout_type + self._layout_cls = layout_type # Alias for compatibility self._layout_params = layout_params + self._params = layout_params # Alias for compatibility self.device = qdata.device self.dtype = qdata.dtype From 9907a5e4f5b807a910b0969240aa2646d45395c4 Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:52:25 +0800 Subject: [PATCH 5/8] fix(quant_ops): add numel, size, shape, dim, and ndim to mock QuantizedTensor Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1e0cc0304..e80e6bcdc 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -83,6 +83,30 @@ except ImportError as e: self._qdata.requires_grad_(requires_grad) return self + def numel(self): + if hasattr(self._layout_params, "orig_shape"): + import math + return math.prod(self._layout_params.orig_shape) + return self._qdata.numel() + + @property + def shape(self): + if hasattr(self._layout_params, "orig_shape"): + return torch.Size(self._layout_params.orig_shape) + return self._qdata.shape + + @property + def ndim(self): + return len(self.shape) + + def size(self, dim=None): + if dim is None: + return self.shape + return self.shape[dim] + + def dim(self): + return self.ndim + def __getattr__(self, name): if name == "params": return self._layout_params From 96803b16c0b3d8a9e781015e48c8e682d0679cba Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:56:03 +0800 Subject: [PATCH 6/8] fix(quant_ops): ensure QuantizedTensor.to(dtype=...) 
updates orig_dtype to prevent precision mismatch RuntimeErrors Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index e80e6bcdc..8f282ee5b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -60,9 +60,12 @@ except ImportError as e: def to(self, *args, **kwargs): device = kwargs.get("device", None) - if device is None and len(args) > 0: + dtype = kwargs.get("dtype", None) + if len(args) > 0: if isinstance(args[0], (torch.device, str)): device = args[0] + elif isinstance(args[0], torch.dtype): + dtype = args[0] new_qdata = self._qdata.to(*args, **kwargs) new_params = self._layout_params.copy() @@ -71,6 +74,9 @@ except ImportError as e: if isinstance(v, torch.Tensor): new_params.__dict__[k] = v.to(device=device) + if dtype is not None: + new_params.orig_dtype = dtype + return type(self)(new_qdata, self._layout_type, new_params) def detach(self): From ea3ec049bdd92ad7197be6008f5444968cb9c793 Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:58:36 +0800 Subject: [PATCH 7/8] fix(quant_ops): implement __torch_function__ to support torch.empty_like for mock QuantizedTensor Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8f282ee5b..96fd211d5 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -118,6 +118,18 @@ except ImportError as e: return self._layout_params raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.empty_like: + input_t = args[0] + if isinstance(input_t, cls): + dtype = kwargs.get("dtype", input_t.dtype) + device = kwargs.get("device", input_t.device) + return torch.empty(input_t.shape, dtype=dtype, device=device) + return NotImplemented + def __torch_dispatch__(self, func, types, args=(), kwargs=None): return NotImplemented From 38f5db0118811c4bfaaa5e43630bf60bd3a7fa1c Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:59:48 +0800 Subject: [PATCH 8/8] fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 96fd211d5..b6bce4db1 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -128,6 +128,12 @@ except ImportError as e: dtype = kwargs.get("dtype", input_t.dtype) device = kwargs.get("device", input_t.device) return torch.empty(input_t.shape, dtype=dtype, device=device) + + if func is torch.Tensor.copy_: + dst, src = args[:2] + if isinstance(src, cls): + return dst.copy_(src.dequantize(), **kwargs) + return NotImplemented def __torch_dispatch__(self, func, types, args=(), kwargs=None):
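
Note: taken together, patches 7 and 8 let plain torch functions operate on the
mock tensor through the __torch_function__ override protocol. A reduced,
self-contained sketch of that pattern; the ByteBox class and its fields are
illustrative stand-ins, not the patch's actual class:

    import torch

    class ByteBox:  # illustrative stand-in for the mock QuantizedTensor
        def __init__(self, data: torch.Tensor):
            self._data = data          # plays the role of _qdata
            self.shape = data.shape
            self.dtype = data.dtype
            self.device = data.device

        def dequantize(self) -> torch.Tensor:
            return self._data.to(torch.float32)

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func is torch.empty_like and isinstance(args[0], cls):
                t = args[0]
                return torch.empty(t.shape, dtype=kwargs.get("dtype", t.dtype),
                                   device=kwargs.get("device", t.device))
            if func is torch.Tensor.copy_ and isinstance(args[1], cls):
                dst, src = args[0], args[1]
                # Dequantize the source, then fall back to a plain tensor copy.
                return dst.copy_(src.dequantize(), *args[2:], **kwargs)
            return NotImplemented

    box = ByteBox(torch.randint(0, 256, (4,), dtype=torch.uint8))
    buf = torch.empty_like(box, dtype=torch.float32)  # dispatches to ByteBox.__torch_function__
    buf.copy_(box)                                    # likewise, via torch.Tensor.copy_
    assert buf.dtype == torch.float32

This mirrors how the patched fallback lets torch.empty_like allocate a
destination buffer and lets copy_ transparently dequantize the quantized
source during weight loading.
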