Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-28 07:10:15 +08:00)

Compare commits: 6588d9f53f...0f2f0ef453 (10 commits)
| SHA1 |
|---|
| 0f2f0ef453 |
| c6238047ee |
| 38f5db0118 |
| ea3ec049bd |
| 96803b16c0 |
| 9907a5e4f5 |
| e3cc20034d |
| 77a46c68ea |
| 406dab2d53 |
| ef7b4a717a |
@@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
 If you have trouble extracting it, right click the file -> properties -> unblock
 
-Update your Nvidia drivers if it doesn't start.
+The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
 
 #### Alternative Downloads:
@@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
 
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
 
-torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
+torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
 
 ### Instructions:
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: perform float8 conversion on CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
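The loop above stages results slice by slice; the actual float8 rounding happens inside `manual_stochastic_round_to_float8`. As a minimal sketch of the underlying idea only (rounding to the integer grid for simplicity rather than the float8 grid; this is not ComfyUI's implementation): each value rounds up with probability equal to its fractional remainder, so the result is unbiased in expectation.

```python
import torch

def stochastic_round_sketch(value, generator=None):
    # Split into integer part and fractional remainder, then round up with
    # probability equal to the remainder, so the result is unbiased on average.
    floor = value.floor()
    frac = value - floor
    noise = torch.rand(value.shape, generator=generator, device=value.device)
    return floor + (noise < frac).to(value.dtype)

g = torch.Generator().manual_seed(0)
x = torch.full((10000,), 0.25)
print(stochastic_round_sketch(x, g).mean())  # ~0.25: rounds up 25% of the time
```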
comfy/mps_ops.py (new file)
import torch

_LUT_CACHE = {}


def get_lut(dtype, device):
    """
    Get or create a lookup table for float8 dequantization on MPS.
    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
    """
    key = (dtype, device)
    if key in _LUT_CACHE:
        return _LUT_CACHE[key]

    # Build the decoding table from all 256 possible byte patterns: view them
    # as the target float8 type on CPU (where float8 views are supported),
    # widen to float16, then move the finished LUT to the requested device.
    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
    try:
        f8_tensor = byte_pattern.view(dtype)
        f16_lut = f8_tensor.to(torch.float16)
        lut = f16_lut.to(device)
        _LUT_CACHE[key] = lut
        return lut
    except Exception as e:
        print(f"Failed to create MPS LUT for {dtype}: {e}")
        raise


def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
    """
    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.

    Args:
        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
        scale: Tensor (scalar)
        orig_dtype: The target dtype (e.g. float16)
        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)

    Returns:
        Tensor of shape (...) with dtype=orig_dtype
    """
    lut = get_lut(float8_dtype, qdata.device)

    # The LUT must be in the target dtype (likely float16 or bfloat16).
    if lut.dtype != orig_dtype:
        lut = lut.to(dtype=orig_dtype)

    # Advanced indexing lut[qdata.long()] is generally efficient; indices
    # must be int64, so cast the byte codes to long.
    output = lut[qdata.long()]

    # Apply the dequantization scale, cast to the output dtype if needed.
    if isinstance(scale, torch.Tensor):
        scale = scale.to(dtype=orig_dtype)

    output.mul_(scale)
    return output
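The LUT trick is easy to sanity-check on CPU, assuming a PyTorch build with float8 dtypes: every byte pattern must decode through the table to the same value a direct float8-to-float16 cast produces. A quick self-contained check:

```python
import torch

# Build the same 256-entry table get_lut() builds: bytes -> float8 -> float16.
lut = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.float16)

# Decode a few random byte patterns both ways.
codes = torch.randint(0, 256, (8,), dtype=torch.uint8)
direct = codes.view(torch.float8_e4m3fn).to(torch.float16)
via_lut = lut[codes.long()]

# NaN byte patterns (0x7f / 0xff in e4m3fn) compare unequal, so zero them first.
assert torch.equal(direct.nan_to_num(0.0), via_lut.nan_to_num(0.0))
```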
@@ -28,22 +28,148 @@ except ImportError as e:
     logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
 
+    class ck_dummy:
+        @staticmethod
+        def quantize_per_tensor_fp8(tensor, scale, dtype):
+            return (tensor / scale.to(tensor.device)).to(dtype)
+    ck = ck_dummy
+
+
+class QuantizedTensor:
+    def __init__(self, qdata, layout_type, layout_params):
+        self._qdata = qdata
+        self._layout_type = layout_type
+        self._layout_cls = layout_type  # Alias for compatibility
+        self._layout_params = layout_params
+        self._params = layout_params  # Alias for compatibility
+        self.device = qdata.device
+        self.dtype = qdata.dtype
+
+    @classmethod
+    def from_float(cls, tensor, layout_type, **kwargs):
+        layout_cls = get_layout_class(layout_type)
+        if layout_cls is None:
+            raise ValueError(f"Unknown layout type: {layout_type}")
+        qdata, params = layout_cls.quantize(tensor, **kwargs)
+        return cls(qdata, layout_type, params)
+
+    def dequantize(self):
+        layout_cls = get_layout_class(self._layout_type)
+        if layout_cls is None:
+            return self._qdata
+        return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
+
+    def to(self, *args, **kwargs):
+        device = kwargs.get("device", None)
+        dtype = kwargs.get("dtype", None)
+        if len(args) > 0:
+            if isinstance(args[0], (torch.device, str)):
+                device = args[0]
+            elif isinstance(args[0], torch.dtype):
+                dtype = args[0]
+
+        new_qdata = self._qdata.to(*args, **kwargs)
+        new_params = self._layout_params.copy()
+        if device is not None:
+            for k, v in new_params.__dict__.items():
+                if isinstance(v, torch.Tensor):
+                    new_params.__dict__[k] = v.to(device=device)
+
+        if dtype is not None:
+            new_params.orig_dtype = dtype
+
+        return type(self)(new_qdata, self._layout_type, new_params)
+
+    def detach(self):
+        return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy())
+
+    def clone(self):
+        return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy())
+
+    def requires_grad_(self, requires_grad=True):
+        self._qdata.requires_grad_(requires_grad)
+        return self
+
+    def numel(self):
+        if hasattr(self._layout_params, "orig_shape"):
+            import math
+            return math.prod(self._layout_params.orig_shape)
+        return self._qdata.numel()
+
+    @property
+    def shape(self):
+        if hasattr(self._layout_params, "orig_shape"):
+            return torch.Size(self._layout_params.orig_shape)
+        return self._qdata.shape
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def size(self, dim=None):
+        if dim is None:
+            return self.shape
+        return self.shape[dim]
+
+    def dim(self):
+        return self.ndim
+
+    def __getattr__(self, name):
+        if name == "params":
+            return self._layout_params
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func is torch.empty_like:
+            input_t = args[0]
+            if isinstance(input_t, cls):
+                dtype = kwargs.get("dtype", input_t.dtype)
+                device = kwargs.get("device", input_t.device)
+                return torch.empty(input_t.shape, dtype=dtype, device=device)
+
+        if func is torch.Tensor.copy_:
+            dst, src = args[:2]
+            if isinstance(src, cls):
+                return dst.copy_(src.dequantize(), **kwargs)
+
+        return NotImplemented
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        return NotImplemented
+
+
+class QuantizedLayout:
+    class Params:
+        def __init__(self, **kwargs):
+            for k, v in kwargs.items():
+                setattr(self, k, v)
+
+        def copy(self):
+            return type(self)(**self.__dict__)
+
+
-class _CKFp8Layout:
+class _CKFp8Layout(QuantizedLayout):
     pass
 
-class TensorCoreNVFP4Layout:
+class TensorCoreNVFP4Layout(QuantizedLayout):
     pass
 
+_LOCAL_LAYOUT_REGISTRY = {}
+
 def register_layout_class(name, cls):
-    pass
+    _LOCAL_LAYOUT_REGISTRY[name] = cls
 
 def get_layout_class(name):
-    return None
+    return _LOCAL_LAYOUT_REGISTRY.get(name)
 
 def register_layout_op(torch_op, layout_type):
     def decorator(handler_func):
         return handler_func
     return decorator
 
 
 import comfy.float
+import comfy.mps_ops
+
+# ==============================================================================
+# FP8 Layouts with Comfy-Specific Extensions
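As an aside, the registry round trip above can be exercised with a toy layout. `ToyLayout` and the `"toy"` name below are hypothetical, chosen only for illustration, and the sketch assumes the `QuantizedTensor`, `QuantizedLayout`, and registry definitions from this hunk are in scope:

```python
import torch

class ToyLayout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor):
        # Illustrative per-tensor scale; real layouts choose this differently.
        scale = tensor.abs().amax().clamp(min=1e-12)
        qdata = (tensor / scale).to(torch.float16)
        return qdata, cls.Params(scale=scale.float(), orig_dtype=tensor.dtype,
                                 orig_shape=tuple(tensor.shape))

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, orig_shape=None):
        return qdata.to(orig_dtype) * scale.to(orig_dtype)

register_layout_class("toy", ToyLayout)

qt = QuantizedTensor.from_float(torch.randn(4, 4), "toy")
print(qt.shape, qt.numel())   # reports the original geometry from Params
print(qt.dequantize().dtype)  # torch.float32, restored from orig_dtype
```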
@@ -51,7 +177,13 @@ import comfy.float
 
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
 
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         if cls.FP8_DTYPE is None:
@@ -83,6 +215,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # It is MPS Float8. View as uint8.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
         plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
         plain_tensor.mul_(scale)
         return plain_tensor
 
 
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
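The generic (non-MPS) path above is just the per-tensor fp8 scheme run in reverse: cast back to the original dtype, then multiply the scale in. A small hedged round-trip example, with an illustrative scale choice (the commit's actual scale selection lives in the `quantize` classmethod, which this hunk shows only partially):

```python
import torch

# Illustrative per-tensor scale: map the tensor's max magnitude to fp8's max.
w = torch.randn(128, 128, dtype=torch.float16)
scale = w.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max

qdata = (w.float() / scale).to(torch.float8_e4m3fn)   # quantize
plain = qdata.to(torch.float16)                       # dequantize (cast back)
plain.mul_(scale.to(torch.float16))                   # re-apply the scale
print((w - plain).abs().max())                        # small fp8 rounding error
```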