diff --git a/comfy/ops.py b/comfy/ops.py index 3f088a962..b9cd277cb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1071,9 +1071,20 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat if module.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") + if module.quant_format not in QUANT_ALGOS: + raise ValueError( + f"Quantization format '{module.quant_format}' for layer {layer_name} " + f"is not available in this build (supported: {sorted(QUANT_ALGOS.keys())}). " + "Update comfy_kitchen to enable it." + ) + qconfig = QUANT_ALGOS[module.quant_format] module.layout_type = qconfig["comfy_tensor_layout"] layout_cls = get_layout_class(module.layout_type) + module._layout_cls = layout_cls + # W4A16-style layouts keep the activation in compute dtype; the forward + # path reads this to decide whether to quantize the input. + module._layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True) # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8. if module.quant_format in ("float8_e4m3fn", "float8_e5m2"): @@ -1089,6 +1100,35 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat if ts is None or bs is None: raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") scales = {"scale": ts, "block_scale": bs} + elif module.quant_format == "svdquant_w4a4": + # SVDQuant W4A4: per-group weight scales + low-rank correction + # (proj_down, proj_up) + activation smoothing (smooth_factor). + wscales = pop_scale("weight_scale") + proj_down = pop_scale("proj_down") + proj_up = pop_scale("proj_up") + smooth_factor = pop_scale("smooth_factor") + if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)): + raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}") + scales = { + "scale": wscales, + "proj_down": proj_down, + "proj_up": proj_up, + "smooth_factor": smooth_factor, + "act_unsigned": bool(layer_conf.get("act_unsigned", False)), + } + elif module.quant_format == "awq_w4a16": + # AWQ W4A16: int4 weight, fp16/bf16 activation. Used by + # Qwen-Image-Edit modulation linears so they stay packed instead of + # being dequantized to bf16 at load time. + wscales = pop_scale("weight_scale") + wzeros = pop_scale("weight_zero") + if wscales is None or wzeros is None: + raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}") + scales = { + "scale": wscales, + "zeros": wzeros, + "group_size": int(layer_conf.get("group_size", qconfig.get("group_size", 64))), + } else: raise ValueError(f"Unsupported quantization format: {module.quant_format}") @@ -1178,7 +1218,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = destination if destination is not None else {} - return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",)) + # Preserve the SVDQuant W4A4 act_unsigned flag on round-trip save. + _params = getattr(getattr(self, 'weight', None), '_params', None) + extra_quant_conf = {"act_unsigned": True} if getattr(_params, 'act_unsigned', False) else None + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf=extra_quant_conf, extra_quant_params=("input_scale",)) def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -1228,18 +1271,18 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Inference path (unchanged) if _use_quantized: + if getattr(self, "_layout_quantizes_input", True): + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - - # Fall back to non-quantized for non-2D tensors - if input_reshaped.ndim == 2: - reshaped_3d = input.ndim == 3 - # dtype is now implicit in the layout class - scale = getattr(self, 'input_scale', None) - if scale is not None: - scale = comfy.model_management.cast_to_device(scale, input.device, None) - input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b90bcfd25..cae5f1180 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,8 +20,14 @@ try: else: cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) if cuda_version < (13,): - ck.registry.disable("cuda") - logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") + # cu<13 lacks the block-scale FP4 cuBLASLt APIs but not the int4 + # MMA or fp8 paths. Kitchen's per-op FunctionConstraints already + # gate scaled_mm_nvfp4 behind HAS_CUBLASLT, so we keep the CUDA + # backend enabled for svdquant_w4a4 / fp8 / mxfp8 / rope. + logging.warning( + "cuda_version=%s < 13: NVFP4 cuBLAS path unavailable; " + "other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.", + ".".join(map(str, cuda_version))) if args.enable_triton_backend: try: @@ -47,6 +53,12 @@ except ImportError as e: class _CKNvfp4Layout: pass + class _CKSVDQuantW4A4Layout: + pass + + class _CKAWQW4A16Layout: + pass + def register_layout_class(name, cls): pass @@ -65,6 +77,30 @@ if not _CK_MXFP8_AVAILABLE: class _CKMxfp8Layout: pass +_CK_SVDQUANT_W4A4_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout + _CK_SVDQUANT_W4A4_AVAILABLE = True + except ImportError: + logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.") + +if not _CK_SVDQUANT_W4A4_AVAILABLE: + class _CKSVDQuantW4A4Layout: + pass + +_CK_AWQ_W4A16_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout + _CK_AWQ_W4A16_AVAILABLE = True + except ImportError: + logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will not be supported.") + +if not _CK_AWQ_W4A16_AVAILABLE: + class _CKAWQW4A16Layout: + pass + import comfy.float # ============================================================================== @@ -172,6 +208,19 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e5m2 +# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the +# kitchen-registered layout class unchanged. Comfy-side extension reserved in +# case per-layer input scales or other Comfy-specific metadata are added later. +class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout): + pass + + +# AWQ W4A16 — pre-quantized offline modulation linears. Kitchen owns the +# tensor subclass dispatch and gemv implementation; ComfyUI only loads params. +class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout): + pass + + # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout @@ -186,6 +235,10 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) +if _CK_SVDQUANT_W4A4_AVAILABLE: + register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout) +if _CK_AWQ_W4A16_AVAILABLE: + register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -214,6 +267,22 @@ if _CK_MXFP8_AVAILABLE: "group_size": 32, } +if _CK_SVDQUANT_W4A4_AVAILABLE: + QUANT_ALGOS["svdquant_w4a4"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"}, + "comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout", + "group_size": 64, + } + +if _CK_AWQ_W4A16_AVAILABLE: + QUANT_ALGOS["awq_w4a16"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale", "weight_zero"}, + "comfy_tensor_layout": "TensorCoreAWQW4A16Layout", + "group_size": 64, + } + # ============================================================================== # Re-exports for backward compatibility @@ -226,6 +295,8 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", + "TensorCoreSVDQuantW4A4Layout", + "TensorCoreAWQW4A16Layout", "QUANT_ALGOS", "register_layout_op", ]