Add AWQ W4A16 (modulation) integration with comfy-kitchen

Wires comfy-kitchen's TensorCoreAWQW4A16Layout (introduced on
feat/awq-w4a16-modulation) into ComfyUI's MixedPrecisionOps so checkpoints
that tag modulation linears with comfy_quant.format = "awq_w4a16" get
their (qweight, weight_scale, weight_zero) loaded into the kitchen layout
class instead of being dequantized to bf16 plain Linear at conversion time.

quant_ops.py:
- detect TensorCoreAWQW4A16Layout availability and stub it out for the
  no-kitchen fallback (mirrors the SVDQuant W4A4 pattern)
- register the layout class + add "awq_w4a16" to QUANT_ALGOS
  (storage_t = int8 packed uint4, parameters = {weight_scale, weight_zero},
   default group_size = 64)

ops.py: add the awq_w4a16 branch in MixedPrecisionOps.Linear._load_from_state_dict
that constructs Params(scale, zeros, group_size, ...) and wraps qweight
into a QuantizedTensor — F.linear then dispatches to ck.gemv_awq_w4a16
via the layout's aten handlers.

Pairs with comfy-kitchen feat/awq-w4a16-modulation. Targets the ~10 GB
inflation in Qwen-Image-Edit kitchen-native checkpoints, where the
modulation linears (img_mod.1 / txt_mod.1) currently dominate disk + VRAM
because they're materialized as plain bf16 Linear during conversion.
This commit is contained in:
lax 2026-04-25 19:37:25 +00:00
parent 353978a9b7
commit 3ddcc095ed
2 changed files with 48 additions and 0 deletions

View File

@ -1026,6 +1026,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
smooth_factor=smooth_factor,
act_unsigned=act_unsigned,
)
elif self.quant_format == "awq_w4a16":
# AWQ W4A16: int4 weight, fp16/bf16 activation. Used for
# the modulation linears (img_mod.1 / txt_mod.1) so they
# stay int4 in checkpoint + VRAM rather than getting
# dequantized to bf16 at conversion time (~10 GB saving).
wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
if wscales is None or wzeros is None:
raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
params = layout_cls.Params(
scale=wscales,
zeros=wzeros,
group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")

View File

@ -40,6 +40,9 @@ except ImportError as e:
class _CKSVDQuantW4A4Layout:
pass
class _CKAWQW4A16Layout:
pass
def register_layout_class(name, cls):
pass
@ -68,6 +71,16 @@ if _CK_AVAILABLE:
class _CKSVDQuantW4A4Layout:
pass
_CK_AWQ_W4A16_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
_CK_AWQ_W4A16_AVAILABLE = True
except ImportError:
logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.")
class _CKAWQW4A16Layout:
pass
import comfy.float
# ==============================================================================
@ -182,6 +195,14 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
pass
# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen
# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 /
# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the
# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback.
class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
pass
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
@ -198,6 +219,8 @@ if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
if _CK_SVDQUANT_W4A4_AVAILABLE:
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
if _CK_AWQ_W4A16_AVAILABLE:
register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@ -234,6 +257,14 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
"group_size": 64,
}
if _CK_AWQ_W4A16_AVAILABLE:
QUANT_ALGOS["awq_w4a16"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale", "weight_zero"},
"comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
"group_size": 64,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -242,6 +273,7 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
__all__ = [
"QuantizedTensor",
"QuantizedLayout",
"TensorCoreAWQW4A16Layout",
"TensorCoreFP8Layout",
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",