diff --git a/comfy/ops.py b/comfy/ops.py index 06be577f4..878d35a55 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1026,6 +1026,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec smooth_factor=smooth_factor, act_unsigned=act_unsigned, ) + elif self.quant_format == "awq_w4a16": + # AWQ W4A16: int4 weight, fp16/bf16 activation. Used for + # the modulation linears (img_mod.1 / txt_mod.1) so they + # stay int4 in checkpoint + VRAM rather than getting + # dequantized to bf16 at conversion time (~10 GB saving). + wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys) + if wscales is None or wzeros is None: + raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}") + params = layout_cls.Params( + scale=wscales, + zeros=wzeros, + group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) else: raise ValueError(f"Unsupported quantization format: {self.quant_format}") diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c6978771f..4e76ea49c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -40,6 +40,9 @@ except ImportError as e: class _CKSVDQuantW4A4Layout: pass + class _CKAWQW4A16Layout: + pass + def register_layout_class(name, cls): pass @@ -68,6 +71,16 @@ if _CK_AVAILABLE: class _CKSVDQuantW4A4Layout: pass +_CK_AWQ_W4A16_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout + _CK_AWQ_W4A16_AVAILABLE = True + except ImportError: + logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.") + class _CKAWQW4A16Layout: + pass + import comfy.float # ============================================================================== @@ -182,6 +195,14 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout): pass +# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen +# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 / +# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the +# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback. +class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout): + pass + + # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout @@ -198,6 +219,8 @@ if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) if _CK_SVDQUANT_W4A4_AVAILABLE: register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout) +if _CK_AWQ_W4A16_AVAILABLE: + register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -234,6 +257,14 @@ if _CK_SVDQUANT_W4A4_AVAILABLE: "group_size": 64, } +if _CK_AWQ_W4A16_AVAILABLE: + QUANT_ALGOS["awq_w4a16"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale", "weight_zero"}, + "comfy_tensor_layout": "TensorCoreAWQW4A16Layout", + "group_size": 64, + } + # ============================================================================== # Re-exports for backward compatibility @@ -242,6 +273,7 @@ if _CK_SVDQUANT_W4A4_AVAILABLE: __all__ = [ "QuantizedTensor", "QuantizedLayout", + "TensorCoreAWQW4A16Layout", "TensorCoreFP8Layout", "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout",