Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 06:40:48 +08:00

Compare commits: 3de2e71a79 ... bb98e5a5ea

11 Commits

| SHA1 |
|---|
| bb98e5a5ea |
| ec0a832acb |
| 04c49a29b4 |
| 38f5db0118 |
| ea3ec049bd |
| 96803b16c0 |
| 9907a5e4f5 |
| e3cc20034d |
| 77a46c68ea |
| 406dab2d53 |
| ef7b4a717a |
```diff
@@ -81,7 +81,8 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
     """
     targets: list[tuple[str, list[str]]] = []
     models_root = os.path.abspath(folder_paths.models_dir)
-    for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
+    for name, values in folder_paths.folder_names_and_paths.items():
+        paths, _exts = values[0], values[1]  # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
         if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
             targets.append((name, paths))
     return targets
```
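For context, here is a small self-contained sketch (hypothetical data, not ComfyUI's real registry contents) of the failure mode the NOTE comment refers to: a node pack that stores more than two elements in a `folder_names_and_paths` entry breaks strict tuple unpacking, while positional indexing keeps working.

```python
# Hypothetical example data; the real entries live in folder_paths.folder_names_and_paths.
folder_names_and_paths = {
    "checkpoints": (["/models/checkpoints"], {".safetensors", ".ckpt"}),
    # A hacky node pack might append a third element to the tuple:
    "loras": (["/models/loras"], {".safetensors"}, {"custom": "metadata"}),
}

# Old style: raises ValueError ("too many values to unpack") on the "loras" entry.
# for name, (paths, _exts) in folder_names_and_paths.items(): ...

# New style: tolerant of extra elements.
for name, values in folder_names_and_paths.items():
    paths, _exts = values[0], values[1]
    print(name, paths)
```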
```diff
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
-        generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+        # MPS workaround: perform float8 conversion on CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
+        generator.manual_seed(seed)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
```
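As background for the slice loop above, the sketch below illustrates the general idea behind stochastic rounding on a uniform grid. It is an illustration only, not ComfyUI's `manual_stochastic_round_to_float8` (real float8 grids are non-uniform), but it shows the property the conversion preserves: the rounded value equals the input in expectation.

```python
import torch

def stochastic_round_to_grid(x, step, generator=None):
    # Round each element down to a multiple of `step`, then round up with
    # probability equal to the remaining fraction, so E[result] == x.
    scaled = x / step
    lower = torch.floor(scaled)
    frac = scaled - lower
    noise = torch.rand(x.shape, generator=generator, device=x.device, dtype=x.dtype)
    return (lower + (noise < frac).to(x.dtype)) * step

x = torch.full((100000,), 0.3)
print(stochastic_round_to_grid(x, step=1.0).mean())  # close to 0.3, even though the grid is {0, 1, 2, ...}
```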
comfy/mps_ops.py (new file, 77 lines):

```python
import torch

_LUT_CACHE = {}


def get_lut(dtype, device):
    """
    Get or create a lookup table for float8 dequantization on MPS.
    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
    """
    key = (dtype, device)
    if key in _LUT_CACHE:
        return _LUT_CACHE[key]

    # Generate all possible 8-bit values (0-255). We create them on CPU first
    # as float8, cast to float16, then move to MPS. This acts as our decoding table:
    # 1. Create bytes 0..255
    # 2. View as float8 (on CPU, where it is supported)
    # 3. Convert to float16 (on CPU)
    # 4. Move the float16 LUT to MPS

    # Create uint8 pattern 0..255
    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")

    # View as the target float8 type. Note: .view() requires the same number of
    # bytes per element; uint8 and float8 are both 1 byte, so the
    # reinterpretation is valid.
    try:
        f8_tensor = byte_pattern.view(dtype)
        f16_lut = f8_tensor.to(torch.float16)

        # Move to the requested MPS device
        lut = f16_lut.to(device)
        _LUT_CACHE[key] = lut
        return lut
    except Exception as e:
        print(f"Failed to create MPS LUT for {dtype}: {e}")
        raise e


def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
    """
    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.

    Args:
        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
        scale: Tensor (scalar)
        orig_dtype: The target dtype (e.g. float16)
        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)

    Returns:
        Tensor of shape (...) with dtype=orig_dtype
    """
    lut = get_lut(float8_dtype, qdata.device)

    # Advanced indexing lut[qdata.long()] is generally efficient; the cast to
    # long (int64) is required for indexing. Flattening might be slightly
    # faster depending on shape, but simple indexing is safest.

    # The LUT should be in the target orig_dtype (likely float16 or bfloat16).
    if lut.dtype != orig_dtype:
        lut = lut.to(dtype=orig_dtype)

    output = lut[qdata.long()]

    # Apply the scale; it may need to be cast to orig_dtype too.
    if isinstance(scale, torch.Tensor):
        scale = scale.to(dtype=orig_dtype)

    output.mul_(scale)
    return output
```
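A quick sanity check of the LUT path, as a sketch: it assumes a PyTorch build with float8 dtypes and that the module above is importable as `comfy.mps_ops`. Because dequantization here is just byte indexing into a 256-entry table, the check also runs on CPU, no MPS device required.

```python
import torch
from comfy.mps_ops import mps_dequantize

w = torch.randn(4, 8).to(torch.float8_e4m3fn)      # pretend quantized weight
scale = torch.tensor(0.5, dtype=torch.float32)

via_lut = mps_dequantize(w.view(torch.uint8), scale, torch.float16, torch.float8_e4m3fn)
direct = w.to(torch.float16) * scale.to(torch.float16)

assert torch.allclose(via_lut, direct)
```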
```diff
@@ -28,22 +28,148 @@ except ImportError as e:
     logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
 
     class ck_dummy:
         @staticmethod
         def quantize_per_tensor_fp8(tensor, scale, dtype):
             return (tensor / scale.to(tensor.device)).to(dtype)
     ck = ck_dummy
 
+    class QuantizedTensor:
+        def __init__(self, qdata, layout_type, layout_params):
+            self._qdata = qdata
+            self._layout_type = layout_type
+            self._layout_cls = layout_type  # Alias for compatibility
+            self._layout_params = layout_params
+            self._params = layout_params  # Alias for compatibility
+            self.device = qdata.device
+            self.dtype = qdata.dtype
+
+        @classmethod
+        def from_float(cls, tensor, layout_type, **kwargs):
+            layout_cls = get_layout_class(layout_type)
+            if layout_cls is None:
+                raise ValueError(f"Unknown layout type: {layout_type}")
+            qdata, params = layout_cls.quantize(tensor, **kwargs)
+            return cls(qdata, layout_type, params)
+
+        def dequantize(self):
+            layout_cls = get_layout_class(self._layout_type)
+            if layout_cls is None:
+                return self._qdata
+            return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
+
+        def to(self, *args, **kwargs):
+            device = kwargs.get("device", None)
+            dtype = kwargs.get("dtype", None)
+            if len(args) > 0:
+                if isinstance(args[0], (torch.device, str)):
+                    device = args[0]
+                elif isinstance(args[0], torch.dtype):
+                    dtype = args[0]
+
+            new_qdata = self._qdata.to(*args, **kwargs)
+            new_params = self._layout_params.copy()
+            if device is not None:
+                for k, v in new_params.__dict__.items():
+                    if isinstance(v, torch.Tensor):
+                        new_params.__dict__[k] = v.to(device=device)
+
+            if dtype is not None:
+                new_params.orig_dtype = dtype
+
+            return type(self)(new_qdata, self._layout_type, new_params)
+
+        def detach(self):
+            return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy())
+
+        def clone(self):
+            return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy())
+
+        def requires_grad_(self, requires_grad=True):
+            self._qdata.requires_grad_(requires_grad)
+            return self
+
+        def numel(self):
+            if hasattr(self._layout_params, "orig_shape"):
+                import math
+                return math.prod(self._layout_params.orig_shape)
+            return self._qdata.numel()
+
+        @property
+        def shape(self):
+            if hasattr(self._layout_params, "orig_shape"):
+                return torch.Size(self._layout_params.orig_shape)
+            return self._qdata.shape
+
+        @property
+        def ndim(self):
+            return len(self.shape)
+
+        def size(self, dim=None):
+            if dim is None:
+                return self.shape
+            return self.shape[dim]
+
+        def dim(self):
+            return self.ndim
+
+        def __getattr__(self, name):
+            if name == "params":
+                return self._layout_params
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+        @classmethod
+        def __torch_function__(cls, func, types, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+            if func is torch.empty_like:
+                input_t = args[0]
+                if isinstance(input_t, cls):
+                    dtype = kwargs.get("dtype", input_t.dtype)
+                    device = kwargs.get("device", input_t.device)
+                    return torch.empty(input_t.shape, dtype=dtype, device=device)
+
+            if func is torch.Tensor.copy_:
+                dst, src = args[:2]
+                if isinstance(src, cls):
+                    return dst.copy_(src.dequantize(), **kwargs)
+
+            return NotImplemented
+
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            return NotImplemented
+
+    class QuantizedLayout:
+        class Params:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+            def copy(self):
+                return type(self)(**self.__dict__)
+
-    class _CKFp8Layout:
+    class _CKFp8Layout(QuantizedLayout):
         pass
 
-    class TensorCoreNVFP4Layout:
+    class TensorCoreNVFP4Layout(QuantizedLayout):
         pass
 
+    _LOCAL_LAYOUT_REGISTRY = {}
+
     def register_layout_class(name, cls):
-        pass
+        _LOCAL_LAYOUT_REGISTRY[name] = cls
 
     def get_layout_class(name):
-        return None
+        return _LOCAL_LAYOUT_REGISTRY.get(name)
 
     def register_layout_op(torch_op, layout_type):
         def decorator(handler_func):
             return handler_func
         return decorator
 
 
 import comfy.float
+import comfy.mps_ops
 
 
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
```
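To see how the fallback pieces above fit together, here is a minimal sketch that registers a made-up layout (`ToyInt8Layout`, purely illustrative) and exercises `QuantizedTensor.from_float`, `dequantize`, and the `__torch_function__` handlers. It assumes the fallback path (comfy_kitchen unavailable), so `register_layout_class`/`get_layout_class` are the local registry shown above and those names are in scope.

```python
import torch

class ToyInt8Layout(QuantizedLayout):              # hypothetical layout, for illustration only
    @classmethod
    def quantize(cls, tensor, **kwargs):
        scale = tensor.abs().amax().clamp(min=1e-8) / 127.0
        qdata = (tensor / scale).round().clamp(-127, 127).to(torch.int8)
        params = cls.Params(scale=scale.float(), orig_dtype=tensor.dtype,
                            orig_shape=tuple(tensor.shape))
        return qdata, params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale.to(orig_dtype)

register_layout_class("ToyInt8Layout", ToyInt8Layout)

w = torch.randn(16, 32)
qt = QuantizedTensor.from_float(w, "ToyInt8Layout")
assert qt.shape == w.shape                           # logical shape comes from params.orig_shape
recon = qt.dequantize()                              # float32 again, approximately equal to w

buf = torch.empty_like(qt, dtype=torch.float32)      # __torch_function__ returns a plain tensor
buf.copy_(qt)                                        # copy_ dequantizes the QuantizedTensor source
assert torch.allclose(buf, recon)
```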
```diff
@@ -51,7 +177,13 @@ import comfy.float
 
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
 
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         if cls.FP8_DTYPE is None:
```
```diff
@@ -83,6 +215,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # It is MPS Float8. View as uint8.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        plain_tensor.mul_(scale)
+        return plain_tensor
+
 
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
```
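The quantize() body is elided in the hunk above, but the dequantize() path shown (value times scale) implies a standard per-tensor absmax scheme. The sketch below is an assumption-labelled illustration of that scheme, not the verbatim ComfyUI implementation; it needs a PyTorch build with float8 dtypes.

```python
import torch

def fp8_per_tensor_quantize(tensor, fp8_dtype=torch.float8_e4m3fn):
    # Assumed scheme: scale so the largest magnitude lands near the top of the
    # FP8 range, then cast.
    fp8_max = torch.finfo(fp8_dtype).max            # 448.0 for e4m3fn, 57344.0 for e5m2
    scale = tensor.abs().amax().clamp(min=1e-12) / fp8_max
    qdata = (tensor / scale).to(fp8_dtype)
    return qdata, scale.float()

def fp8_per_tensor_dequantize(qdata, scale, orig_dtype=torch.float16):
    out = qdata.to(orig_dtype)                      # mirrors the _to_copy + mul_ path above
    out.mul_(scale)
    return out

w = torch.randn(128, 64)
q, s = fp8_per_tensor_quantize(w)
w_hat = fp8_per_tensor_dequantize(q, s, torch.float32)
print((w - w_hat).abs().max())                      # small quantization error
```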
```diff
@@ -14,8 +14,9 @@ class JobStatus:
     IN_PROGRESS = 'in_progress'
     COMPLETED = 'completed'
     FAILED = 'failed'
+    CANCELLED = 'cancelled'
 
-    ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
+    ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
 
 
 # Media types that can be previewed in the frontend
```
```diff
@@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
 
     status_info = history_item.get('status', {})
     status_str = status_info.get('status_str') if status_info else None
-    if status_str == 'success':
-        status = JobStatus.COMPLETED
-    elif status_str == 'error':
-        status = JobStatus.FAILED
-    else:
-        status = JobStatus.COMPLETED
 
     outputs = history_item.get('outputs', {})
     outputs_count, preview_output = get_outputs_summary(outputs)
```
```diff
@@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
     execution_error = None
     execution_start_time = None
     execution_end_time = None
+    was_interrupted = False
     if status_info:
         messages = status_info.get('messages', [])
         for entry in messages:
@@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
                 execution_end_time = event_data.get('timestamp')
             if event_name == 'execution_error':
                 execution_error = event_data
+            elif event_name == 'execution_interrupted':
+                was_interrupted = True
+
+    if status_str == 'success':
+        status = JobStatus.COMPLETED
+    elif status_str == 'error':
+        status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
+    else:
+        status = JobStatus.COMPLETED
 
     job = prune_dict({
         'id': prompt_id,
```
```diff
@@ -268,13 +273,13 @@ def get_all_jobs(
     for item in queued:
         jobs.append(normalize_queue_item(item, JobStatus.PENDING))
 
-    include_completed = JobStatus.COMPLETED in status_filter
-    include_failed = JobStatus.FAILED in status_filter
-    if include_completed or include_failed:
+    history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
+    requested_history_statuses = history_statuses & set(status_filter)
+    if requested_history_statuses:
         for prompt_id, history_item in history.items():
-            is_failed = history_item.get('status', {}).get('status_str') == 'error'
-            if (is_failed and include_failed) or (not is_failed and include_completed):
-                jobs.append(normalize_history_item(prompt_id, history_item))
+            job = normalize_history_item(prompt_id, history_item)
+            if job.get('status') in requested_history_statuses:
+                jobs.append(job)
 
     if workflow_id:
         jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
```
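To make the behavioural change concrete (hypothetical filter values, not an actual API call): cancelled history entries are now selected by matching each normalized job's status against the requested set, instead of re-deriving completed/failed from status_str alone.

```python
status_filter = ['completed', 'cancelled']               # hypothetical request
history_statuses = {'completed', 'failed', 'cancelled'}
requested_history_statuses = history_statuses & set(status_filter)

assert requested_history_statuses == {'completed', 'cancelled'}
# A run whose history has status_str == 'error' plus an execution_interrupted
# message now normalizes to 'cancelled', so it is included here; the old
# include_completed/include_failed logic would have lumped it in with failures.
```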
```diff
@@ -19,6 +19,7 @@ class TestJobStatus:
         assert JobStatus.IN_PROGRESS == 'in_progress'
         assert JobStatus.COMPLETED == 'completed'
         assert JobStatus.FAILED == 'failed'
+        assert JobStatus.CANCELLED == 'cancelled'
 
     def test_all_contains_all_statuses(self):
         """ALL should contain all status values."""
@@ -26,7 +27,8 @@ class TestJobStatus:
         assert JobStatus.IN_PROGRESS in JobStatus.ALL
         assert JobStatus.COMPLETED in JobStatus.ALL
         assert JobStatus.FAILED in JobStatus.ALL
-        assert len(JobStatus.ALL) == 4
+        assert JobStatus.CANCELLED in JobStatus.ALL
+        assert len(JobStatus.ALL) == 5
 
 
 class TestIsPreviewable:
```
```diff
@@ -336,6 +338,40 @@ class TestNormalizeHistoryItem:
         assert job['execution_error']['node_type'] == 'KSampler'
         assert job['execution_error']['exception_message'] == 'CUDA out of memory'
 
+    def test_cancelled_job(self):
+        """Cancelled/interrupted history item should have cancelled status."""
+        history_item = {
+            'prompt': (
+                5,
+                'prompt-cancelled',
+                {'nodes': {}},
+                {'create_time': 1234567890000},
+                ['node1'],
+            ),
+            'status': {
+                'status_str': 'error',
+                'completed': False,
+                'messages': [
+                    ('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}),
+                    ('execution_interrupted', {
+                        'prompt_id': 'prompt-cancelled',
+                        'node_id': '5',
+                        'node_type': 'KSampler',
+                        'executed': ['1', '2', '3'],
+                        'timestamp': 1234567891000,
+                    })
+                ]
+            },
+            'outputs': {},
+        }
+
+        job = normalize_history_item('prompt-cancelled', history_item)
+        assert job['status'] == 'cancelled'
+        assert job['execution_start_time'] == 1234567890500
+        assert job['execution_end_time'] == 1234567891000
+        # Cancelled jobs should not have execution_error set
+        assert 'execution_error' not in job
+
     def test_include_outputs(self):
         """When include_outputs=True, should include full output data."""
         history_item = {
```