Compare commits

...

11 Commits

Author SHA1 Message Date
Macpaul Lin
bb98e5a5ea
Merge 38f5db0118 into ec0a832acb 2026-01-09 17:58:19 +09:00
Jedrzej Kosinski
ec0a832acb
Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. (#11755)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 22:49:12 -08:00
ric-yu
04c49a29b4
feat: add cancelled filter to /jobs (#11680) 2026-01-08 21:57:36 -08:00
Macpaul Lin
38f5db0118 fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ea3ec049bd fix(quant_ops): implement __torch_function__ to support torch.empty_like for mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
96803b16c0 fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
9907a5e4f5 fix(quant_ops): add numel, size, shape, dim, and ndim to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
e3cc20034d fix(quant_ops): add _layout_cls and _params aliases to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
77a46c68ea fix(quant_ops): add detach, clone, and requires_grad_ to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
406dab2d53 fix(quant_ops): improve comfy_kitchen fallback logic to prevent loading errors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
Macpaul Lin
ef7b4a717a feat(mps): implement native-like Float8 support via LUT dequantization
Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
Lookup Table (LUT) for efficient dequantization, keeping data on the GPU.

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8
  when moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to
  prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: Add fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
2026-01-09 02:00:41 +08:00
6 changed files with 301 additions and 24 deletions

View File

@ -81,7 +81,8 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets

View File

@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
# MPS workaround: perform float8 conversion on CPU
target_device = value.device
use_cpu_staging = (target_device.type == "mps")
output_device = "cpu" if use_cpu_staging else target_device
output = torch.empty_like(value, dtype=dtype, device=output_device)
generator = torch.Generator(device=target_device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
if use_cpu_staging:
res = res.cpu()
output[i:i+slice_size].copy_(res)
if use_cpu_staging:
return output.to(target_device)
return output
return value.to(dtype=dtype)

77
comfy/mps_ops.py Normal file
View File

@ -0,0 +1,77 @@
import torch
_LUT_CACHE = {}
def get_lut(dtype, device):
"""
Get or create a lookup table for float8 dequantization on MPS.
Returns a Tensor[256] of dtype=torch.float16 on the specified device.
"""
key = (dtype, device)
if key in _LUT_CACHE:
return _LUT_CACHE[key]
# Generate all possible 8-bit values (0-255)
# We create them on CPU first as float8, then cast to float16, then move to MPS.
# This acts as our decoding table.
# Create uint8 pattern 0..255
byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
# View as the target float8 type
# Note: We must use .view() on a tensor that has the same number of bytes.
# We can't view uint8 as float8 directly if standard pytorch doesn't allow it easily,
# but we can create the float8 tensor from bytes.
# Actually, the easiest way to generate the LUT is:
# 1. Create bytes 0..255
# 2. View as float8 (on CPU, where it is supported)
# 3. Convert to float16 (on CPU)
# 4. Move float16 LUT to MPS
try:
f8_tensor = byte_pattern.view(dtype)
f16_lut = f8_tensor.to(torch.float16)
# Move to the requested MPS device
lut = f16_lut.to(device)
_LUT_CACHE[key] = lut
return lut
except Exception as e:
print(f"Failed to create MPS LUT for {dtype}: {e}")
# Fallback: return None or raise
raise e
def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
"""
Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
Args:
qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
scale: Tensor (scalar)
orig_dtype: The target dtype (e.g. float16)
float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
Returns:
Tensor of shape (...) with dtype=orig_dtype
"""
lut = get_lut(float8_dtype, qdata.device)
# Use index_select or advanced indexing.
# Advanced indexing lut[qdata.long()] is generally efficient.
# We explicitly cast to long (int64) for indexing.
# Note: Flattening might be slightly faster depending on shape, but simple indexing is safest.
# We want the LUT to be in the target orig_dtype (likely float16 or bfloat16)
if lut.dtype != orig_dtype:
lut = lut.to(dtype=orig_dtype)
output = lut[qdata.long()]
# Apply scale
# Scale might need to be cast to orig_dtype too
if isinstance(scale, torch.Tensor):
scale = scale.to(dtype=orig_dtype)
output.mul_(scale)
return output

View File

@ -28,22 +28,148 @@ except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
class ck_dummy:
@staticmethod
def quantize_per_tensor_fp8(tensor, scale, dtype):
return (tensor / scale.to(tensor.device)).to(dtype)
ck = ck_dummy
class QuantizedTensor:
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_cls = layout_type # Alias for compatibility
self._layout_params = layout_params
self._params = layout_params # Alias for compatibility
self.device = qdata.device
self.dtype = qdata.dtype
@classmethod
def from_float(cls, tensor, layout_type, **kwargs):
layout_cls = get_layout_class(layout_type)
if layout_cls is None:
raise ValueError(f"Unknown layout type: {layout_type}")
qdata, params = layout_cls.quantize(tensor, **kwargs)
return cls(qdata, layout_type, params)
def dequantize(self):
layout_cls = get_layout_class(self._layout_type)
if layout_cls is None:
return self._qdata
return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
def to(self, *args, **kwargs):
device = kwargs.get("device", None)
dtype = kwargs.get("dtype", None)
if len(args) > 0:
if isinstance(args[0], (torch.device, str)):
device = args[0]
elif isinstance(args[0], torch.dtype):
dtype = args[0]
new_qdata = self._qdata.to(*args, **kwargs)
new_params = self._layout_params.copy()
if device is not None:
for k, v in new_params.__dict__.items():
if isinstance(v, torch.Tensor):
new_params.__dict__[k] = v.to(device=device)
if dtype is not None:
new_params.orig_dtype = dtype
return type(self)(new_qdata, self._layout_type, new_params)
def detach(self):
return type(self)(self._qdata.detach(), self._layout_type, self._layout_params.copy())
def clone(self):
return type(self)(self._qdata.clone(), self._layout_type, self._layout_params.copy())
def requires_grad_(self, requires_grad=True):
self._qdata.requires_grad_(requires_grad)
return self
def numel(self):
if hasattr(self._layout_params, "orig_shape"):
import math
return math.prod(self._layout_params.orig_shape)
return self._qdata.numel()
@property
def shape(self):
if hasattr(self._layout_params, "orig_shape"):
return torch.Size(self._layout_params.orig_shape)
return self._qdata.shape
@property
def ndim(self):
return len(self.shape)
def size(self, dim=None):
if dim is None:
return self.shape
return self.shape[dim]
def dim(self):
return self.ndim
def __getattr__(self, name):
if name == "params":
return self._layout_params
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.empty_like:
input_t = args[0]
if isinstance(input_t, cls):
dtype = kwargs.get("dtype", input_t.dtype)
device = kwargs.get("device", input_t.device)
return torch.empty(input_t.shape, dtype=dtype, device=device)
if func is torch.Tensor.copy_:
dst, src = args[:2]
if isinstance(src, cls):
return dst.copy_(src.dequantize(), **kwargs)
return NotImplemented
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return NotImplemented
class QuantizedLayout:
class Params:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def copy(self):
return type(self)(**self.__dict__)
class _CKFp8Layout(QuantizedLayout):
pass
class _CKFp8Layout:
class TensorCoreNVFP4Layout(QuantizedLayout):
pass
class TensorCoreNVFP4Layout:
pass
_LOCAL_LAYOUT_REGISTRY = {}
def register_layout_class(name, cls):
pass
_LOCAL_LAYOUT_REGISTRY[name] = cls
def get_layout_class(name):
return None
return _LOCAL_LAYOUT_REGISTRY.get(name)
def register_layout_op(torch_op, layout_type):
def decorator(handler_func):
return handler_func
return decorator
import comfy.float
import comfy.mps_ops
# ==============================================================================
# FP8 Layouts with Comfy-Specific Extensions
@ -51,7 +177,13 @@ import comfy.float
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
FP8_DTYPE = None # Must be overridden in subclass
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if cls.FP8_DTYPE is None:
@ -83,6 +215,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
if qdata.device.type == "mps":
if qdata.dtype == torch.uint8:
return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
elif qdata.is_floating_point() and qdata.element_size() == 1:
# It is MPS Float8. View as uint8.
return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn

View File

@ -14,8 +14,9 @@ class JobStatus:
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
CANCELLED = 'cancelled'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Media types that can be previewed in the frontend
@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_error = None
execution_start_time = None
execution_end_time = None
was_interrupted = False
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
elif event_name == 'execution_interrupted':
was_interrupted = True
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
else:
status = JobStatus.COMPLETED
job = prune_dict({
'id': prompt_id,
@ -268,13 +273,13 @@ def get_all_jobs(
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]

View File

@ -19,6 +19,7 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed'
assert JobStatus.CANCELLED == 'cancelled'
def test_all_contains_all_statuses(self):
"""ALL should contain all status values."""
@ -26,7 +27,8 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4
assert JobStatus.CANCELLED in JobStatus.ALL
assert len(JobStatus.ALL) == 5
class TestIsPreviewable:
@ -336,6 +338,40 @@ class TestNormalizeHistoryItem:
assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_cancelled_job(self):
"""Cancelled/interrupted history item should have cancelled status."""
history_item = {
'prompt': (
5,
'prompt-cancelled',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}),
('execution_interrupted', {
'prompt_id': 'prompt-cancelled',
'node_id': '5',
'node_type': 'KSampler',
'executed': ['1', '2', '3'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-cancelled', history_item)
assert job['status'] == 'cancelled'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
# Cancelled jobs should not have execution_error set
assert 'execution_error' not in job
def test_include_outputs(self):
"""When include_outputs=True, should include full output data."""
history_item = {