From 353978a9b7e6aa0a166c6fd39f3adc49d42c8cb3 Mon Sep 17 00:00:00 2001 From: lax Date: Mon, 20 Apr 2026 06:55:48 +0000 Subject: [PATCH 1/2] Add SVDQuant W4A4 integration with comfy-kitchen (kitchen-native row-major) quant_ops.py: register TensorCoreSVDQuantW4A4Layout when comfy-kitchen exposes it; gate the kitchen CUDA backend on cuda >= 13 (the optimized kitchen CUDA ops are validated against cu13+ runtimes; on older cu the backend falls back to eager). ops.py: handle svdquant_w4a4 quant_format by loading weight_scale / proj_down / proj_up / smooth_factor into TensorCoreSVDQuantW4A4Layout.Params, with the img_mlp.net.2 / txt_mlp.net.2 fallback for act_unsigned. Targets the row-major kitchen-native kernels on feat/svdquant-w4a4-kitchen-native; the verbatim zgemm path is a sibling branch. --- comfy/ops.py | 57 ++++++++++++++++++++++++++++++++++++++-------- comfy/quant_ops.py | 31 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 7a9b4b84c..06be577f4 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -997,6 +997,35 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec orig_dtype=MixedPrecisionOps._compute_dtype, orig_shape=(self.out_features, self.in_features), ) + elif self.quant_format == "svdquant_w4a4": + # SVDQuant W4A4: per-group weight scales + low-rank correction + # (proj_down, proj_up) + activation smoothing (smooth_factor) + wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + proj_down = self._load_scale_param(state_dict, prefix, "proj_down", device, manually_loaded_keys) + proj_up = self._load_scale_param(state_dict, prefix, "proj_up", device, manually_loaded_keys) + smooth_factor = self._load_scale_param(state_dict, prefix, "smooth_factor", device, manually_loaded_keys) + act_unsigned = bool(layer_conf.get("act_unsigned", False)) + + # Early Qwen-Image conversion artifacts did not persist the + # fused GELU -> fc2 unsigned-activation flag. Those layers + # are the second linear in the feed-forward block. + if not act_unsigned and ( + layer_name.endswith(".img_mlp.net.2") or layer_name.endswith(".txt_mlp.net.2") + ): + act_unsigned = True + + if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)): + raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}") + + params = layout_cls.Params( + scale=wscales, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + proj_down=proj_down, + proj_up=proj_up, + smooth_factor=smooth_factor, + act_unsigned=act_unsigned, + ) else: raise ValueError(f"Unsupported quantization format: {self.quant_format}") @@ -1046,6 +1075,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec quant_conf = {"format": self.quant_format} if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True + if bool(getattr(getattr(self.weight, "_params", None), "act_unsigned", False)): + quant_conf["act_unsigned"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) input_scale = getattr(self, 'input_scale', None) @@ -1103,18 +1134,24 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Inference path (unchanged) if _use_quantized: + # Some layouts (e.g. SVDQuant W4A4) do activation quantization + # inside their fused kernel and cannot pre-quantize a float + # tensor up-front. Skip the input wrapping for those. + layout_cls = get_layout_class(self.layout_type) + layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True) - # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input + if layout_quantizes_input: + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - # Fall back to non-quantized for non-2D tensors - if input_reshaped.ndim == 2: - reshaped_3d = input.ndim == 3 - # dtype is now implicit in the layout class - scale = getattr(self, 'input_scale', None) - if scale is not None: - scale = comfy.model_management.cast_to_device(scale, input.device, None) - input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb2..c6978771f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -37,6 +37,9 @@ except ImportError as e: class _CKNvfp4Layout: pass + class _CKSVDQuantW4A4Layout: + pass + def register_layout_class(name, cls): pass @@ -55,6 +58,16 @@ if not _CK_MXFP8_AVAILABLE: class _CKMxfp8Layout: pass +_CK_SVDQUANT_W4A4_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout + _CK_SVDQUANT_W4A4_AVAILABLE = True + except ImportError: + logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.") + class _CKSVDQuantW4A4Layout: + pass + import comfy.float # ============================================================================== @@ -162,6 +175,13 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e5m2 +# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the +# kitchen-registered layout class unchanged. Comfy-side extension reserved in +# case per-layer input scales or other Comfy-specific metadata are added later. +class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout): + pass + + # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout @@ -176,6 +196,8 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) +if _CK_SVDQUANT_W4A4_AVAILABLE: + register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -204,6 +226,14 @@ if _CK_MXFP8_AVAILABLE: "group_size": 32, } +if _CK_SVDQUANT_W4A4_AVAILABLE: + QUANT_ALGOS["svdquant_w4a4"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"}, + "comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout", + "group_size": 64, + } + # ============================================================================== # Re-exports for backward compatibility @@ -216,6 +246,7 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", + "TensorCoreSVDQuantW4A4Layout", "QUANT_ALGOS", "register_layout_op", ] From 3ddcc095ed07da148da7df98744966b9e04ad75f Mon Sep 17 00:00:00 2001 From: lax Date: Sat, 25 Apr 2026 19:37:25 +0000 Subject: [PATCH 2/2] Add AWQ W4A16 (modulation) integration with comfy-kitchen MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires comfy-kitchen's TensorCoreAWQW4A16Layout (introduced on feat/awq-w4a16-modulation) into ComfyUI's MixedPrecisionOps so checkpoints that tag modulation linears with comfy_quant.format = "awq_w4a16" get their (qweight, weight_scale, weight_zero) loaded into the kitchen layout class instead of being dequantized to bf16 plain Linear at conversion time. quant_ops.py: - detect TensorCoreAWQW4A16Layout availability and stub it out for the no-kitchen fallback (mirrors the SVDQuant W4A4 pattern) - register the layout class + add "awq_w4a16" to QUANT_ALGOS (storage_t = int8 packed uint4, parameters = {weight_scale, weight_zero}, default group_size = 64) ops.py: add the awq_w4a16 branch in MixedPrecisionOps.Linear._load_from_state_dict that constructs Params(scale, zeros, group_size, ...) and wraps qweight into a QuantizedTensor — F.linear then dispatches to ck.gemv_awq_w4a16 via the layout's aten handlers. Pairs with comfy-kitchen feat/awq-w4a16-modulation. Targets the ~10 GB inflation in Qwen-Image-Edit kitchen-native checkpoints, where the modulation linears (img_mod.1 / txt_mod.1) currently dominate disk + VRAM because they're materialized as plain bf16 Linear during conversion. --- comfy/ops.py | 16 ++++++++++++++++ comfy/quant_ops.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 06be577f4..878d35a55 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1026,6 +1026,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec smooth_factor=smooth_factor, act_unsigned=act_unsigned, ) + elif self.quant_format == "awq_w4a16": + # AWQ W4A16: int4 weight, fp16/bf16 activation. Used for + # the modulation linears (img_mod.1 / txt_mod.1) so they + # stay int4 in checkpoint + VRAM rather than getting + # dequantized to bf16 at conversion time (~10 GB saving). + wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys) + if wscales is None or wzeros is None: + raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}") + params = layout_cls.Params( + scale=wscales, + zeros=wzeros, + group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) else: raise ValueError(f"Unsupported quantization format: {self.quant_format}") diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c6978771f..4e76ea49c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -40,6 +40,9 @@ except ImportError as e: class _CKSVDQuantW4A4Layout: pass + class _CKAWQW4A16Layout: + pass + def register_layout_class(name, cls): pass @@ -68,6 +71,16 @@ if _CK_AVAILABLE: class _CKSVDQuantW4A4Layout: pass +_CK_AWQ_W4A16_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout + _CK_AWQ_W4A16_AVAILABLE = True + except ImportError: + logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.") + class _CKAWQW4A16Layout: + pass + import comfy.float # ============================================================================== @@ -182,6 +195,14 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout): pass +# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen +# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 / +# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the +# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback. +class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout): + pass + + # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout @@ -198,6 +219,8 @@ if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) if _CK_SVDQUANT_W4A4_AVAILABLE: register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout) +if _CK_AWQ_W4A16_AVAILABLE: + register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -234,6 +257,14 @@ if _CK_SVDQUANT_W4A4_AVAILABLE: "group_size": 64, } +if _CK_AWQ_W4A16_AVAILABLE: + QUANT_ALGOS["awq_w4a16"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale", "weight_zero"}, + "comfy_tensor_layout": "TensorCoreAWQW4A16Layout", + "group_size": 64, + } + # ============================================================================== # Re-exports for backward compatibility @@ -242,6 +273,7 @@ if _CK_SVDQUANT_W4A4_AVAILABLE: __all__ = [ "QuantizedTensor", "QuantizedLayout", + "TensorCoreAWQW4A16Layout", "TensorCoreFP8Layout", "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout",