Compare commits


10 Commits

Author SHA1 Message Date
Macpaul Lin
0f2f0ef453
Merge 38f5db0118 into c6238047ee 2026-01-12 13:24:47 +05:00
comfyanonymous
c6238047ee
Put more details about portable in readme. (#11816)
2026-01-11 21:11:53 -05:00
Macpaul Lin
38f5db0118 fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ea3ec049bd fix(quant_ops): implement __torch_function__ to support torch.empty_like for mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
96803b16c0 fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
9907a5e4f5 fix(quant_ops): add numel, size, shape, dim, and ndim to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
e3cc20034d fix(quant_ops): add _layout_cls and _params aliases to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
77a46c68ea fix(quant_ops): add detach, clone, and requires_grad_ to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
406dab2d53 fix(quant_ops): improve comfy_kitchen fallback logic to prevent loading errors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ef7b4a717a feat(mps): implement native-like Float8 support via LUT dequantization
Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
Lookup Table (LUT) for efficient dequantization, keeping data on the GPU.

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8
  when moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to
  prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: Add fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
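In outline, the LUT strategy described in ef7b4a717a works as follows (a minimal sketch, not the PR's code; assumes a PyTorch build where float8 views and casts are supported on CPU):

    import torch

    # Every float8 value occupies one byte, so a 256-entry table decodes any
    # byte pattern. Build it on CPU, where float8 views are supported:
    lut = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.float16)

    # Float8 weights shipped to MPS as raw uint8 bytes then decode with a
    # single gather (shown here on CPU; on MPS the table and data stay on GPU):
    payload = torch.randint(0, 256, (8,), dtype=torch.uint8)
    decoded = lut[payload.long()]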
4 changed files with 246 additions and 11 deletions

README.md

@@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
 If you have trouble extracting it, right click the file -> properties -> unblock
-Update your Nvidia drivers if it doesn't start.
+The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
 #### Alternative Downloads:
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
-torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
 ### Instructions:

comfy/float.py

@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: perform float8 conversion on CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
     return value.to(dtype=dtype)
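For intuition, here is a CPU-only walk-through of the staged-copy pattern above (illustrative only: the real rounding happens in manual_stochastic_round_to_float8, stood in here by a plain cast):

    import torch

    value = torch.randn(8, 4)
    output = torch.empty_like(value, dtype=torch.float8_e4m3fn, device="cpu")
    num_slices = max(1, (value.numel() / (4096 * 4096)))
    slice_size = max(1, round(value.shape[0] / num_slices))
    for i in range(0, value.shape[0], slice_size):
        # stand-in for manual_stochastic_round_to_float8(...)
        output[i:i+slice_size].copy_(value[i:i+slice_size].to(torch.float8_e4m3fn))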

comfy/mps_ops.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+import torch
+
+_LUT_CACHE = {}
+
+def get_lut(dtype, device):
+    """
+    Get or create a lookup table for float8 dequantization on MPS.
+    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
+    """
+    key = (dtype, device)
+    if key in _LUT_CACHE:
+        return _LUT_CACHE[key]
+    # Generate all 256 possible 8-bit patterns and decode them once:
+    #   1. Create bytes 0..255 on CPU
+    #   2. View them as the target float8 type (supported on CPU)
+    #   3. Convert to float16 (on CPU)
+    #   4. Move the float16 LUT to the requested MPS device
+    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
+    try:
+        f8_tensor = byte_pattern.view(dtype)
+        f16_lut = f8_tensor.to(torch.float16)
+        lut = f16_lut.to(device)
+        _LUT_CACHE[key] = lut
+        return lut
+    except Exception as e:
+        print(f"Failed to create MPS LUT for {dtype}: {e}")
+        raise
+
+def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
+    """
+    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
+
+    Args:
+        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
+        scale: Tensor (scalar)
+        orig_dtype: The target dtype (e.g. float16)
+        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
+    Returns:
+        Tensor of shape (...) with dtype=orig_dtype
+    """
+    lut = get_lut(float8_dtype, qdata.device)
+    # The LUT must be in the target orig_dtype (typically float16 or bfloat16).
+    if lut.dtype != orig_dtype:
+        lut = lut.to(dtype=orig_dtype)
+    # Advanced indexing (lut[qdata.long()]) is generally efficient and keeps
+    # the data on the GPU; indices must be int64.
+    output = lut[qdata.long()]
+    # Apply the scale, casting it to orig_dtype if it is a tensor.
+    if isinstance(scale, torch.Tensor):
+        scale = scale.to(dtype=orig_dtype)
+    output.mul_(scale)
+    return output
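A quick way to sanity-check the LUT path on CPU (not part of the PR; on an MPS build qdata would live on device="mps", where the reference decode below is not available):

    import torch
    import comfy.mps_ops as mps_ops

    qdata = torch.randint(0, 256, (16,), dtype=torch.uint8)
    scale = torch.tensor(0.5)
    out = mps_ops.mps_dequantize(qdata, scale, torch.float16, torch.float8_e4m3fn)
    # Reference decode without the LUT (CPU only); NaN byte patterns will
    # compare unequal, so inspect rather than assert equality.
    ref = qdata.view(torch.float8_e4m3fn).to(torch.float16) * scale.half()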

comfy/quant_ops.py

@@ -28,22 +28,148 @@ except ImportError as e:
     logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
     class ck_dummy:
         @staticmethod
         def quantize_per_tensor_fp8(tensor, scale, dtype):
             return (tensor / scale.to(tensor.device)).to(dtype)
     ck = ck_dummy
     class QuantizedTensor:
+        def __init__(self, qdata, layout_type, layout_params):
+            self._qdata = qdata
+            self._layout_type = layout_type
+            self._layout_cls = layout_type  # Alias for compatibility
+            self._layout_params = layout_params
+            self._params = layout_params  # Alias for compatibility
+            self.device = qdata.device
+            self.dtype = qdata.dtype
+        @classmethod
+        def from_float(cls, tensor, layout_type, **kwargs):
+            layout_cls = get_layout_class(layout_type)
+            if layout_cls is None:
+                raise ValueError(f"Unknown layout type: {layout_type}")
+            qdata, params = layout_cls.quantize(tensor, **kwargs)
+            return cls(qdata, layout_type, params)
+        def dequantize(self):
+            layout_cls = get_layout_class(self._layout_type)
+            if layout_cls is None:
+                return self._qdata
+            return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
+        def to(self, *args, **kwargs):
+            device = kwargs.get("device", None)
+            dtype = kwargs.get("dtype", None)
+            if len(args) > 0:
+                if isinstance(args[0], (torch.device, str)):
+                    device = args[0]
+                elif isinstance(args[0], torch.dtype):
+                    dtype = args[0]
+            new_qdata = self._qdata.to(*args, **kwargs)
+            new_params = self._layout_params.copy()
+            if device is not None:
+                for k, v in new_params.__dict__.items():
+                    if isinstance(v, torch.Tensor):
+                        new_params.__dict__[k] = v.to(device=device)
+            if dtype is not None:
+                new_params.orig_dtype = dtype
+            return type(self)(new_qdata, self._layout_type, new_params)
+        def detach(self):
+            return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy())
+        def clone(self):
+            return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy())
+        def requires_grad_(self, requires_grad=True):
+            self._qdata.requires_grad_(requires_grad)
+            return self
+        def numel(self):
+            if hasattr(self._layout_params, "orig_shape"):
+                import math
+                return math.prod(self._layout_params.orig_shape)
+            return self._qdata.numel()
+        @property
+        def shape(self):
+            if hasattr(self._layout_params, "orig_shape"):
+                return torch.Size(self._layout_params.orig_shape)
+            return self._qdata.shape
+        @property
+        def ndim(self):
+            return len(self.shape)
+        def size(self, dim=None):
+            if dim is None:
+                return self.shape
+            return self.shape[dim]
+        def dim(self):
+            return self.ndim
+        def __getattr__(self, name):
+            if name == "params":
+                return self._layout_params
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+        @classmethod
+        def __torch_function__(cls, func, types, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+            if func is torch.empty_like:
+                input_t = args[0]
+                if isinstance(input_t, cls):
+                    dtype = kwargs.get("dtype", input_t.dtype)
+                    device = kwargs.get("device", input_t.device)
+                    return torch.empty(input_t.shape, dtype=dtype, device=device)
+            if func is torch.Tensor.copy_:
+                dst, src = args[:2]
+                if isinstance(src, cls):
+                    return dst.copy_(src.dequantize(), **kwargs)
+            return NotImplemented
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            return NotImplemented
+    class QuantizedLayout:
+        class Params:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+            def copy(self):
+                return type(self)(**self.__dict__)
-    class _CKFp8Layout:
-        pass
+    class _CKFp8Layout(QuantizedLayout):
+        pass
-    class TensorCoreNVFP4Layout:
-        pass
+    class TensorCoreNVFP4Layout(QuantizedLayout):
+        pass
+    _LOCAL_LAYOUT_REGISTRY = {}
     def register_layout_class(name, cls):
-        pass
+        _LOCAL_LAYOUT_REGISTRY[name] = cls
     def get_layout_class(name):
-        return None
+        return _LOCAL_LAYOUT_REGISTRY.get(name)
     def register_layout_op(torch_op, layout_type):
         def decorator(handler_func):
             return handler_func
         return decorator
 import comfy.float
+import comfy.mps_ops
+# ==============================================================================
+# FP8 Layouts with Comfy-Specific Extensions
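To make the fallback concrete, a hedged sketch of using the mock directly (assumes comfy_kitchen is absent so the definitions above are active; the "fp8" layout name here is arbitrary and unregistered, so dequantize() falls back to returning the raw qdata):

    import torch
    from comfy.quant_ops import QuantizedTensor, QuantizedLayout

    params = QuantizedLayout.Params(scale=torch.tensor(1.0),
                                    orig_dtype=torch.float16,
                                    orig_shape=(4, 4))
    qt = QuantizedTensor(torch.zeros(4, 4, dtype=torch.float8_e4m3fn), "fp8", params)

    print(qt.shape)                                # torch.Size([4, 4]), from orig_shape
    t = torch.empty_like(qt, dtype=torch.float16)  # routed through __torch_function__
    t.copy_(qt)                                    # copy_ receives qt.dequantize()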
@@ -51,7 +177,13 @@ import comfy.float
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         if cls.FP8_DTYPE is None:
@@ -83,6 +215,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # It is MPS Float8. View as uint8.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        plain_tensor.mul_(scale)
+        return plain_tensor
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn