diff --git a/comfy/ops.py b/comfy/ops.py
index 7e34831e0..f5e1e9230 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
@@ -488,11 +488,8 @@ if CUBLAS_IS_AVAILABLE:
 from .quant_ops import (
     QuantizedTensor,
     QUANT_ALGOS,
-    LAYOUTS,
     TensorCoreFP8Layout,
-    TensorCoreFP8E4M3Layout,
-    TensorCoreFP8E5M2Layout,
-    TensorCoreNVFP4Layout
+    get_layout_class,
 )
 
 
@@ -567,7 +564,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             qconfig = QUANT_ALGOS[self.quant_format]
             self.layout_type = qconfig["comfy_tensor_layout"]
-            layout_cls = LAYOUTS[self.layout_type]
+            layout_cls = get_layout_class(self.layout_type)
 
             # Load format-specific parameters
             if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
@@ -599,7 +596,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), layout_cls, params),
+                QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
                 requires_grad=False
             )
 
@@ -626,10 +623,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             layout_cls = self.weight._layout_cls
 
             # Check if it's any FP8 variant (E4M3 or E5M2)
-            if layout_cls in (TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout) or \
-               layout_cls.__name__ in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+            if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
                 sd["{}weight_scale".format(prefix)] = self.weight._params.scale
-            elif layout_cls == TensorCoreNVFP4Layout or layout_cls.__name__ == "TensorCoreNVFP4Layout":
+            elif layout_cls == "TensorCoreNVFP4Layout":
                 sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
                 sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
@@ -659,7 +655,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
             if (getattr(self, 'layout_type', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-                layout_cls = LAYOUTS[self.layout_type]
 
                 # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
                 if tensor_3d:
@@ -670,7 +665,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
 
                 # dtype is now implicit in the layout class
-                input = QuantizedTensor.from_float(input, layout_cls, scale=getattr(self, 'input_scale', None))
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
 
             output = self._forward(input, self.weight, self.bias)
 
@@ -688,9 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             if getattr(self, 'layout_type', None) is not None:
-                layout_cls = LAYOUTS[self.layout_type]
                 # dtype is now implicit in the layout class
-                weight = QuantizedTensor.from_float(weight, layout_cls, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
+                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
             else:
                 weight = weight.to(self.weight.dtype)
             if return_weight:
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 72dc4b3e1..b1f8ac010 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -6,65 +6,39 @@ from typing import Dict
 try:
     import comfy_kitchen as ck
     from comfy_kitchen.tensor import (
-        QuantizedTensor as _CKQuantizedTensor,
+        QuantizedTensor,
         QuantizedLayout,
         TensorCoreFP8Layout as _CKFp8Layout,
         TensorCoreNVFP4Layout,  # Direct import, no wrapper needed
         register_layout_op,
+        register_layout_class,
+        get_layout_class,
     )
     _CK_AVAILABLE = True
+    ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
         logging.info(f"Found comfy_kitchen backend {k}: {v}")
 except ImportError as e:
-    logging.info(f"Failed to import comfy_kitchen, falling back to torch ops. Error: {e}")
+    logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
-    raise ImportError(f"comfy_kitchen is required but not available: {e}")
+
+    class QuantizedTensor:
+        pass
+
+    class _CKFp8Layout:
+        pass
+
+    class TensorCoreNVFP4Layout:
+        pass
+
+    def register_layout_class(name, cls):
+        pass
+
+    def get_layout_class(name):
+        return None
 
 import comfy.float
-
-# ==============================================================================
-# Backward Compatibility Layer
-# ==============================================================================
-
-class QuantizedTensor(_CKQuantizedTensor):
-    @staticmethod
-    def __new__(cls, qdata, layout_cls, params):
-        # Backward compat: Convert string layout names and dict params before __new__
-        if isinstance(layout_cls, str):
-            layout_cls = LAYOUTS[layout_cls]
-
-        if isinstance(params, dict):
-            params = layout_cls.Params(**params)
-
-        return _CKQuantizedTensor.__new__(cls, qdata, layout_cls, params)
-
-    def __init__(self, qdata, layout_cls, params):
-        super().__init__(qdata, layout_cls, params)
-
-    @property
-    def _layout_params(self) -> Dict:
-        return dataclasses.asdict(self._params)
-
-    @property
-    def _layout_type(self) -> str:
-        return self._layout_cls.__name__
-
-    @property
-    def layout_type(self) -> str:
-        """Backward compatibility alias for _layout_type."""
-        return self._layout_type
-
-    def _copy_with(self, qdata=None, params=None, clone_params=True):
-        if params is None:
-            params = self._params.clone() if clone_params else self._params
-        return type(self)(
-            qdata if qdata is not None else self._qdata,
-            self._layout_cls,
-            params,
-        )
-
 
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
 # ==============================================================================
@@ -81,7 +55,10 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         orig_shape = tuple(tensor.shape)
 
         if isinstance(scale, str) and scale == "recalculate":
-            scale = torch.amax(tensor.abs()) / torch.finfo(cls.FP8_DTYPE).max
+            scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
+            if tensor.dtype not in [torch.float32, torch.bfloat16]:  # Prevent scale from being too small
+                tensor_info = torch.finfo(tensor.dtype)
+                scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
 
         if scale is None:
             scale = torch.ones((), device=tensor.device, dtype=torch.float32)
@@ -97,7 +74,7 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         else:
             qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
 
-        params = cls.Params(scale=scale, orig_dtype=orig_dtype, orig_shape=orig_shape)
+        params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
 
         return qdata, params
 
@@ -117,12 +94,10 @@ TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
 # Registry
 # ==============================================================================
 
-LAYOUTS = {
-    "TensorCoreFP8Layout": TensorCoreFP8Layout,  # Backward compat alias (E4M3)
-    "TensorCoreFP8E4M3Layout": TensorCoreFP8E4M3Layout,
-    "TensorCoreFP8E5M2Layout": TensorCoreFP8E5M2Layout,
-    "TensorCoreNVFP4Layout": TensorCoreNVFP4Layout,  # Direct from comfy_kitchen
-}
+register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
+register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
+register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
+register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
 
 QUANT_ALGOS = {
     "float8_e4m3fn": {
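
Below is a minimal, illustrative sketch (not the comfy_kitchen implementation) of the string-keyed registry pattern this patch moves to: modules keep only the layout name as a plain string, resolve the class through get_layout_class when it is actually needed, and state-dict checks compare strings, so they keep working even when comfy_kitchen (and its layout classes) is unavailable. DummyFP8Layout and the registry dict below are hypothetical stand-ins used only for the example.

    # Illustrative sketch only; mirrors the register_layout_class/get_layout_class
    # calls shown in the patch, not the real comfy_kitchen internals.
    _LAYOUT_REGISTRY = {}

    def register_layout_class(name, cls):
        # Map a stable string name to a layout class.
        _LAYOUT_REGISTRY[name] = cls

    def get_layout_class(name):
        # Resolve a layout name back to its class (None if unknown).
        return _LAYOUT_REGISTRY.get(name)

    class DummyFP8Layout:  # hypothetical stand-in for e.g. TensorCoreFP8E4M3Layout
        pass

    register_layout_class("TensorCoreFP8E4M3Layout", DummyFP8Layout)

    layout_type = "TensorCoreFP8E4M3Layout"            # what a module would store
    assert get_layout_class(layout_type) is DummyFP8Layout
    # String comparison needs no class object at all, matching the new
    # state-dict checks in the ops.py hunk above:
    assert layout_type in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout")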