From 6da00dd899e3ee6f2a0a8163b080a9f373395025 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:48:58 -0800 Subject: [PATCH 01/11] Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. (#11635) --------- Co-authored-by: Jedrzej Kosinski --- .github/workflows/test-build.yml | 2 +- .github/workflows/test-launch.yml | 4 +- comfy/model_management.py | 4 +- comfy/ops.py | 164 +++-- comfy/quant_ops.py | 641 +++--------------- requirements.txt | 1 + .../comfy_quant/test_mixed_precision.py | 12 +- tests-unit/comfy_quant/test_quant_registry.py | 190 ------ 8 files changed, 223 insertions(+), 795 deletions(-) delete mode 100644 tests-unit/comfy_quant/test_quant_registry.py diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..9160242e9 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index fd70aff23..ef0d3f123 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -32,7 +32,9 @@ jobs: working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | - if grep -qE "Exception|Error" console_output.log; then + grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log + cat console_output_filtered.log + if grep -qE "Exception|Error" console_output_filtered.log; then echo "Unhandled exception/error found in server log." exit 1 fi diff --git a/comfy/model_management.py b/comfy/model_management.py index 7f5a8aee9..22f4de044 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1156,7 +1156,7 @@ def pin_memory(tensor): if not tensor.is_contiguous(): return False - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1183,7 +1183,7 @@ def unpin_memory(tensor): return False ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes size_stored = PINNED_MEMORY.get(ptr, None) if size_stored is None: diff --git a/comfy/ops.py b/comfy/ops.py index 16889bb82..f5e1e9230 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if input is not None: if dtype is None: if isinstance(input, QuantizedTensor): - dtype = input._layout_params["orig_dtype"] + dtype = input.params.orig_dtype else: dtype = input.dtype if bias_dtype is None: @@ -412,26 +412,34 @@ def fp8_linear(self, input): return None input_dtype = input.dtype + input_shape = input.shape + tensor_3d = input.ndim == 3 - if input.ndim == 3 or input.ndim == 2: - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) + if input.ndim != 2: + return None + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - # Wrap weight in QuantizedTensor - this enables unified dispatch - # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! - layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) - o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.to(dtype).contiguous() + layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) + quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input) - uncast_bias_weight(self, w, bias, offload_stream) - return o + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - return None + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_3d: + o = o.reshape((input_shape[0], input_shape[1], w.shape[0])) + + return o class fp8_ops(manual_cast): class Linear(manual_cast.Linear): @@ -477,7 +485,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, QUANT_ALGOS +from .quant_ops import ( + QuantizedTensor, + QUANT_ALGOS, + TensorCoreFP8Layout, + get_layout_class, +) def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): @@ -497,14 +510,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - if dtype is None: - dtype = MixedPrecisionOps._compute_dtype - - self.factory_kwargs = {"device": device, "dtype": dtype} + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features - self._has_bias = bias + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm @@ -512,6 +526,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None + def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -529,14 +553,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - dtype = self.factory_kwargs["dtype"] - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) - if dtype != MixedPrecisionOps._compute_dtype: - self.comfy_cast_weights = True - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) - else: - self.register_parameter("bias", None) + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) if not self._full_precision_mm: @@ -547,31 +564,46 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) - weight_scale_key = f"{prefix}weight_scale" - scale = state_dict.pop(weight_scale_key, None) - if scale is not None: - scale = scale.to(device) - layout_params = { - 'scale': scale, - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } + # Load format-specific parameters + if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: + # FP8: single tensor scale + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if scale is not None: - manually_loaded_keys.append(weight_scale_key) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "nvfp4": + # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.float8_e4m3fn) + + if tensor_scale is None or block_scale is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + else: + raise ValueError(f"Unsupported quantization format: {self.quant_format}") self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), requires_grad=False ) - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) - else: - self.register_parameter("bias", None) - for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue # Already handled above + param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -588,7 +620,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) if isinstance(self.weight, QuantizedTensor): - sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + layout_cls = self.weight._layout_cls + + # Check if it's any FP8 variant (E4M3 or E5M2) + if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): + sd["{}weight_scale".format(prefix)] = self.weight._params.scale + elif layout_cls == "TensorCoreNVFP4Layout": + sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale + sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale + quant_conf = {"format": self.quant_format} if self._full_precision_mm: quant_conf["full_precision_matrix_mult"] = True @@ -607,12 +647,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def forward(self, input, *args, **kwargs): run_every_op() + input_shape = input.shape + tensor_3d = input.ndim == 3 + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) + + if input.ndim != 2: + # Fall back to comfy_cast_weights for non-2D tensors + return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + + # dtype is now implicit in the layout class + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) + + output = self._forward(input, self.weight, self.bias) + + # Reshape output back to 3D if input was 3D + if tensor_3d: + output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + + return output def convert_weight(self, weight, inplace=False, **kwargs): if isinstance(weight, QuantizedTensor): @@ -622,7 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + # dtype is now implicit in the layout class + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..cd737726f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,580 +1,133 @@ import torch import logging -from typing import Tuple, Dict + +try: + import comfy_kitchen as ck + from comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + TensorCoreFP8Layout as _CKFp8Layout, + TensorCoreNVFP4Layout, # Direct import, no wrapper needed + register_layout_op, + register_layout_class, + get_layout_class, + ) + _CK_AVAILABLE = True + ck.registry.disable("triton") + for k, v in ck.list_backends().items(): + logging.info(f"Found comfy_kitchen backend {k}: {v}") +except ImportError as e: + logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.") + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class _CKFp8Layout: + pass + + class TensorCoreNVFP4Layout: + pass + + def register_layout_class(name, cls): + pass + + def get_layout_class(name): + return None + import comfy.float -_LAYOUT_REGISTRY = {} -_GENERIC_UTILS = {} - - -def register_layout_op(torch_op, layout_type): - """ - Decorator to register a layout-specific operation handler. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) - layout_type: Layout class (e.g., TensorCoreFP8Layout) - Example: - @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) - def fp8_linear(func, args, kwargs): - # FP8-specific linear implementation - ... - """ - def decorator(handler_func): - if torch_op not in _LAYOUT_REGISTRY: - _LAYOUT_REGISTRY[torch_op] = {} - _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func - return handler_func - return decorator - - -def register_generic_util(torch_op): - """ - Decorator to register a generic utility that works for all layouts. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - - Example: - @register_generic_util(torch.ops.aten.detach.default) - def generic_detach(func, args, kwargs): - # Works for any layout - ... - """ - def decorator(handler_func): - _GENERIC_UTILS[torch_op] = handler_func - return handler_func - return decorator - - -def _get_layout_from_args(args): - for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type - return None - - -def _move_layout_params_to_device(params, device): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.to(device=device) - else: - new_params[k] = v - return new_params - - -def _copy_layout_params(params): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.clone() - else: - new_params[k] = v - return new_params - -def _copy_layout_params_inplace(src, dst, non_blocking=False): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - dst[k].copy_(v, non_blocking=non_blocking) - else: - dst[k] = v - -class QuantizedLayout: - """ - Base class for quantization layouts. - - A layout encapsulates the format-specific logic for quantization/dequantization - and provides a uniform interface for extracting raw tensors needed for computation. - - New quantization formats should subclass this and implement the required methods. - """ - @classmethod - def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: - raise NotImplementedError(f"{cls.__name__} must implement quantize()") - - @staticmethod - def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError("TensorLayout must implement dequantize()") - - @classmethod - def get_plain_tensors(cls, qtensor) -> torch.Tensor: - raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") - - -class QuantizedTensor(torch.Tensor): - """ - Universal quantized tensor that works with any layout. - - This tensor subclass uses a pluggable layout system to support multiple - quantization formats (FP8, INT4, INT8, etc.) without code duplication. - - The layout_type determines format-specific behavior, while common operations - (detach, clone, to) are handled generically. - - Attributes: - _qdata: The quantized tensor data - _layout_type: Layout class (e.g., TensorCoreFP8Layout) - _layout_params: Dict with layout-specific params (scale, zero_point, etc.) - """ - - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. - - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) - - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params - - def __repr__(self): - layout_name = self._layout_type - param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) - return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - - @property - def layout_type(self): - return self._layout_type - - def __tensor_flatten__(self): - """ - Tensor flattening protocol for proper device movement. - """ - inner_tensors = ["_qdata"] - ctx = { - "layout_type": self._layout_type, - } - - tensor_params = {} - non_tensor_params = {} - for k, v in self._layout_params.items(): - if isinstance(v, torch.Tensor): - tensor_params[k] = v - else: - non_tensor_params[k] = v - - ctx["tensor_param_keys"] = list(tensor_params.keys()) - ctx["non_tensor_params"] = non_tensor_params - - for k, v in tensor_params.items(): - attr_name = f"_layout_param_{k}" - object.__setattr__(self, attr_name, v) - inner_tensors.append(attr_name) - - return inner_tensors, ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): - """ - Tensor unflattening protocol for proper device movement. - Reconstructs the QuantizedTensor after device movement. - """ - layout_type = ctx["layout_type"] - layout_params = dict(ctx["non_tensor_params"]) - - for key in ctx["tensor_param_keys"]: - attr_name = f"_layout_param_{key}" - layout_params[key] = inner_tensors[attr_name] - - return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) - - @classmethod - def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) - return cls(qdata, layout_type, layout_params) - - def dequantize(self) -> torch.Tensor: - return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Step 1: Check generic utilities first (detach, clone, to, etc.) - if func in _GENERIC_UTILS: - return _GENERIC_UTILS[func](func, args, kwargs) - - # Step 2: Check layout-specific handlers (linear, matmul, etc.) - layout_type = _get_layout_from_args(args) - if layout_type and func in _LAYOUT_REGISTRY: - handler = _LAYOUT_REGISTRY[func].get(layout_type) - if handler: - return handler(func, args, kwargs) - - # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) - - @classmethod - def _dequant_and_fallback(cls, func, args, kwargs): - def dequant_arg(arg): - if isinstance(arg, QuantizedTensor): - return arg.dequantize() - elif isinstance(arg, (list, tuple)): - return type(arg)(dequant_arg(a) for a in arg) - return arg - - new_args = dequant_arg(args) - new_kwargs = dequant_arg(kwargs) - return func(*new_args, **new_kwargs) - - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self, *arg, **kwargs): - return self._qdata.is_contiguous(*arg, **kwargs) - - def storage(self): - return self._qdata.storage() - # ============================================================================== -# Generic Utilities (Layout-Agnostic Operations) +# FP8 Layouts with Comfy-Specific Extensions # ============================================================================== -def _create_transformed_qtensor(qt, transform_fn): - new_data = transform_fn(qt._qdata) - new_params = _copy_layout_params(qt._layout_params) - return QuantizedTensor(new_data, qt._layout_type, new_params) +class _TensorCoreFP8LayoutBase(_CKFp8Layout): + FP8_DTYPE = None # Must be overridden in subclass - -def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_layout is not None and target_layout != torch.strided: - logging.warning( - f"QuantizedTensor: layout change requested to {target_layout}, " - f"but not supported. Ignoring layout." - ) - - # Handle device transfer - current_device = qt._qdata.device - if target_device is not None: - # Normalize device for comparison - if isinstance(target_device, str): - target_device = torch.device(target_device) - if isinstance(current_device, str): - current_device = torch.device(current_device) - - if target_device != current_device: - logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") - new_q_data = qt._qdata.to(device=target_device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - if target_dtype is not None: - new_params["orig_dtype"] = target_dtype - new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) - logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") - return new_qt - - logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") - return qt - - -@register_generic_util(torch.ops.aten.detach.default) -def generic_detach(func, args, kwargs): - """Detach operation - creates a detached copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.detach()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.clone.default) -def generic_clone(func, args, kwargs): - """Clone operation - creates a deep copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.clone()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._to_copy.default) -def generic_to_copy(func, args, kwargs): - """Device/dtype transfer operation - handles .to(device) calls.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - op_name="_to_copy" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype_layout) -def generic_to_dtype_layout(func, args, kwargs): - """Handle .to(device) calls using the dtype_layout variant.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - target_layout=kwargs.get('layout', None), - op_name="to" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.copy_.default) -def generic_copy_(func, args, kwargs): - qt_dest = args[0] - src = args[1] - non_blocking = args[2] if len(args) > 2 else False - if isinstance(qt_dest, QuantizedTensor): - if isinstance(src, QuantizedTensor): - # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) - qt_dest._layout_type = src._layout_type - orig_dtype = qt_dest._layout_params["orig_dtype"] - _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) - qt_dest._layout_params["orig_dtype"] = orig_dtype - else: - # Copy from regular tensor - just copy raw data - qt_dest._qdata.copy_(src) - return qt_dest - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = target_dtype - return src - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) -def generic_has_compatible_shallow_copy_type(func, args, kwargs): - return True - - -@register_generic_util(torch.ops.aten.empty_like.default) -def generic_empty_like(func, args, kwargs): - """Empty_like operation - creates an empty tensor with the same quantized structure.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - # Create empty tensor with same shape and dtype as the quantized data - hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) - new_qdata = torch.empty_like(qt._qdata, **kwargs) - - # Handle device transfer for layout params - target_device = kwargs.get('device', new_qdata.device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - - # Update orig_dtype if dtype is specified - new_params['orig_dtype'] = hp_dtype - - return QuantizedTensor(new_qdata, qt._layout_type, new_params) - return func(*args, **kwargs) - -# ============================================================================== -# FP8 Layout + Operation Handlers -# ============================================================================== -class TensorCoreFP8Layout(QuantizedLayout): - """ - Storage format: - - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) - - scale: Scalar tensor (float32) for dequantization - - orig_dtype: Original dtype before quantization (for casting back) - """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if cls.FP8_DTYPE is None: + raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE") + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) if isinstance(scale, str) and scale == "recalculate": - scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small tensor_info = torch.finfo(tensor.dtype) scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if scale is not None: - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is None: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) else: tensor = tensor * (1.0 / scale).to(tensor.dtype) + qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding) else: - scale = torch.ones((), device=tensor.device, dtype=torch.float32) + qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE) - if stochastic_rounding > 0: - tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) - else: - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) - tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape) + return qdata, params - layout_params = { - 'scale': scale, - 'orig_dtype': orig_dtype - } - return tensor, layout_params - @staticmethod - def dequantize(qdata, scale, orig_dtype, **kwargs): - plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - plain_tensor.mul_(scale) - return plain_tensor +class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e4m3fn - @classmethod - def get_plain_tensors(cls, qtensor): - return qtensor._qdata, qtensor._layout_params['scale'] + +class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e5m2 + + +# Backward compatibility alias - default to E4M3 +TensorCoreFP8Layout = TensorCoreFP8E4M3Layout + + +# ============================================================================== +# Registry +# ============================================================================== + +register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) +register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) +register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) +register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, - "comfy_tensor_layout": "TensorCoreFP8Layout", + "comfy_tensor_layout": "TensorCoreFP8E4M3Layout", + }, + "float8_e5m2": { + "storage_t": torch.float8_e5m2, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8E5M2Layout", + }, + "nvfp4": { + "storage_t": torch.uint8, + "parameters": {"weight_scale", "weight_scale_2", "input_scale"}, + "comfy_tensor_layout": "TensorCoreNVFP4Layout", + "group_size": 16, }, } -LAYOUTS = { - "TensorCoreFP8Layout": TensorCoreFP8Layout, -} +# ============================================================================== +# Re-exports for backward compatibility +# ============================================================================== -@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") -def fp8_linear(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - weight_t = plain_weight.t() - - tensor_2d = False - if len(plain_input.shape) == 2: - tensor_2d = True - plain_input = plain_input.unsqueeze(1) - - input_shape = plain_input.shape - if len(input_shape) != 3: - return None - - try: - output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]).contiguous(), - weight_t, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - - if not tensor_2d: - output = output.reshape((-1, input_shape[1], weight.shape[0])) - - if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = scale_a * scale_b - output_params = { - 'scale': output_scale, - 'orig_dtype': input_tensor._layout_params['orig_dtype'] - } - return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) - else: - return output - - except Exception as e: - raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - - # Case 2: DQ Fallback - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - if isinstance(input_tensor, QuantizedTensor): - input_tensor = input_tensor.dequantize() - - return torch.nn.functional.linear(input_tensor, weight, bias) - -def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output - -@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") -def fp8_addmm(func, args, kwargs): - input_tensor = args[1] - weight = args[2] - bias = args[0] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - if isinstance(args[2], QuantizedTensor): - a[2] = args[2].dequantize() - - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") -def fp8_mm(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") -@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") -def fp8_func(func, args, kwargs): - input_tensor = args[0] - if isinstance(input_tensor, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - ar = list(args) - ar[0] = plain_input - return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) - return func(*args, **kwargs) +__all__ = [ + "QuantizedTensor", + "QuantizedLayout", + "TensorCoreFP8Layout", + "TensorCoreFP8E4M3Layout", + "TensorCoreFP8E5M2Layout", + "TensorCoreNVFP4Layout", + "QUANT_ALGOS", + "register_layout_op", +] diff --git a/requirements.txt b/requirements.txt index 3a05799eb..0ee152032 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 +comfy-kitchen>=0.2.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 3a54941e6..7b2eac940 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify scales were loaded - self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) - self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + self.assertEqual(model.layer1.weight._params.scale.item(), 2.0) + self.assertEqual(model.layer3.weight._params.scale.item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py deleted file mode 100644 index 9cb54ede8..000000000 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -def has_gpu(): - return torch.cuda.is_available() - -from comfy.cli_args import args -if not has_gpu(): - args.cpu = True - -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout - - -class TestQuantizedTensor(unittest.TestCase): - """Test the QuantizedTensor subclass with FP8 layout""" - - def test_creation(self): - """Test creating a QuantizedTensor with TensorCoreFP8Layout""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._layout_params['scale'], scale) - self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") - - def test_dequantize(self): - """Test explicit dequantization""" - - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - dequantized = qt.dequantize() - - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_from_float(self): - """Test creating QuantizedTensor from float tensor""" - float_tensor = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - - qt = QuantizedTensor.from_float( - float_tensor, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt.shape, (64, 32)) - - # Verify dequantization gives approximately original values - dequantized = qt.dequantize() - mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.1) - - -class TestGenericUtilities(unittest.TestCase): - """Test generic utility operations""" - - def test_detach(self): - """Test detach operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Detach should return a new QuantizedTensor - qt_detached = qt.detach() - - self.assertIsInstance(qt_detached, QuantizedTensor) - self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") - - def test_clone(self): - """Test clone operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Clone should return a new QuantizedTensor - qt_cloned = qt.clone() - - self.assertIsInstance(qt_cloned, QuantizedTensor) - self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") - - # Verify it's a deep copy - self.assertIsNot(qt_cloned._qdata, qt._qdata) - - @unittest.skipUnless(has_gpu(), "GPU not available") - def test_to_device(self): - """Test device transfer""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Moving to same device should work (CPU to CPU) - qt_cpu = qt.to('cpu') - - self.assertIsInstance(qt_cpu, QuantizedTensor) - self.assertEqual(qt_cpu.device.type, 'cpu') - self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') - - -class TestTensorCoreFP8Layout(unittest.TestCase): - """Test the TensorCoreFP8Layout implementation""" - - def test_quantize(self): - """Test quantization method""" - float_tensor = torch.randn(32, 64, dtype=torch.float32) - scale = torch.tensor(1.5) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertEqual(qdata.dtype, torch.float8_e4m3fn) - self.assertEqual(qdata.shape, float_tensor.shape) - self.assertIn('scale', layout_params) - self.assertIn('orig_dtype', layout_params) - self.assertEqual(layout_params['orig_dtype'], torch.float32) - - def test_dequantize(self): - """Test dequantization method""" - float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 - scale = torch.tensor(1.0) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - - # Should approximately match original - self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_q = QuantizedTensor.from_float( - a_fp32, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, QuantizedTensor) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -if __name__ == "__main__": - unittest.main() From 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:50:35 -0800 Subject: [PATCH 02/11] Use rope functions from comfy kitchen. (#11647) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From 161800241117fae7af90e0c938d0cf8cb2f2ddb1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 5 Jan 2026 20:07:39 -0800 Subject: [PATCH 03/11] Revert "Use rope functions from comfy kitchen. (#11647)" (#11648) This reverts commit 6ef85c49151cf8c4d6bf5e7ccfc566b8d0681cbd. --- comfy/ldm/flux/math.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index f9597de5b..6a22df8bc 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,7 +4,6 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -14,6 +13,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,20 +28,13 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) +def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) -try: - import comfy.quant_ops - apply_rope = comfy.quant_ops.ck.apply_rope - apply_rope1 = comfy.quant_ops.ck.apply_rope1 -except: - logging.warning("No comfy kitchen, using old apply_rope functions.") - def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) - return x_out.reshape(*x.shape).type_as(x) - - def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) From e14f3b661069971163ddc56036b0f486933b9162 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 6 Jan 2026 14:37:11 +0800 Subject: [PATCH 04/11] chore: update workflow templates to v0.7.66 (#11652) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ee152032..9c9c0e29e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.65 +comfyui-workflow-templates==0.7.66 comfyui-embedded-docs==0.3.1 torch torchsde From 96e0d0924e027248733bc6e0b8102dcdc8acde33 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 11:43:24 -0800 Subject: [PATCH 05/11] Add helpful message to portable. (#11671) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 2 +- .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 2 +- .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index ed00583b6..4501ef9a1 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 4898a424f..6487ac7ce 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 32611e4af..01c5bb33b 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause From 6ffc159bdd56d1ad73e954081def6a7f163e7a7f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:53:43 -0800 Subject: [PATCH 06/11] Update comfy-kitchen version to 0.2.1 (#11672) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9c9c0e29e..22cb50e2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.0 +comfy-kitchen>=0.2.1 #non essential dependencies: kornia>=0.7.1 From c3c3e93c5bb3034175c17ef8beeb8fe8626c66ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 13:57:50 -0800 Subject: [PATCH 07/11] Use rope functions from comfy kitchen. (#11674) --- comfy/ldm/flux/math.py | 23 +++++++++++++++-------- requirements.txt | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/requirements.txt b/requirements.txt index 22cb50e2d..7798cb179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.1 +comfy-kitchen>=0.2.2 #non essential dependencies: kornia>=0.7.1 From c3566c0d765200068d26d0888f035504a50012f2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 7 Jan 2026 06:28:29 +0800 Subject: [PATCH 08/11] chore: update workflow templates to v0.7.67 (#11667) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7798cb179..caad0026a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.66 +comfyui-workflow-templates==0.7.67 comfyui-embedded-docs==0.3.1 torch torchsde From 023cf13721cac256c323e2226319b766d07b1f36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:33:03 -0800 Subject: [PATCH 09/11] Fix lowvram issue with ltxv2 text encoder. (#11675) --- comfy/ldm/lightricks/embeddings_connector.py | 2 +- comfy/text_encoders/lt.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index f7a43f3c3..06f5ada89 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -276,7 +276,7 @@ class Embeddings1DConnector(nn.Module): max(1024, hidden_states.shape[1]) / self.num_learnable_registers ) learnable_registers = torch.tile( - self.learnable_registers, (num_registers_duplications, 1) + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) ) hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 2c2d453e8..e5964e42b 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -86,17 +86,19 @@ class LTXAVTEModel(torch.nn.Module): ) def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) self.gemma3_12b.set_clip_options(options) def reset_clip_options(self): self.gemma3_12b.reset_clip_options() + self.execution_device = None def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device - out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) From 6e9ee55cdd9e0eca6b5144063575b983f3311762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:41:27 -0800 Subject: [PATCH 10/11] Disable ltxav previews. (#11676) --- comfy/latent_formats.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9bbe30b53..cb4f52ce1 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -408,7 +408,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] class LTXAV(LTXV): - pass + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None class HunyuanVideo(LatentFormat): latent_channels = 16 From 2c03884f5fb7fa213161dfe1e9a09a8e8c4b6062 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:07:26 -0800 Subject: [PATCH 11/11] Skip fp4 matrix mult on devices that don't support it. (#11677) --- comfy/model_management.py | 10 ++++++++++ comfy/ops.py | 21 +++++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 22f4de044..928282092 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index f5e1e9230..8f9fdce36 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -493,11 +493,12 @@ from .quant_ops import ( ) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None @@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") @@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) return sd @@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not fp8_compute: + disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and