Add AWQ W4A16 (modulation) integration with comfy-kitchen
Wires comfy-kitchen's TensorCoreAWQW4A16Layout (introduced on
feat/awq-w4a16-modulation) into ComfyUI's MixedPrecisionOps so checkpoints
that tag modulation linears with comfy_quant.format = "awq_w4a16" get
their (qweight, weight_scale, weight_zero) loaded into the kitchen layout
class instead of being dequantized to bf16 plain Linear at conversion time.
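For orientation, a sketch of what one tagged layer carries. Only the tensor names (qweight, weight_scale, weight_zero) and the "awq_w4a16" format tag come from this change; the layer path, metadata nesting, and shapes below are illustrative assumptions:

    import torch

    # Hypothetical slice of a kitchen-native checkpoint for one modulation linear.
    # Path, nesting and dims are invented for illustration; only the tensor names
    # and the format tag match this commit.
    out_f, in_f, group = 3072, 3072, 64
    state_dict = {
        "img_mod.1.qweight":      torch.zeros(out_f, in_f // 2, dtype=torch.int8),         # two uint4 per int8 byte
        "img_mod.1.weight_scale": torch.ones(out_f, in_f // group, dtype=torch.bfloat16),  # per-group scales
        "img_mod.1.weight_zero":  torch.zeros(out_f, in_f // group, dtype=torch.bfloat16), # per-group zero points
    }
    comfy_quant = {"img_mod.1": {"format": "awq_w4a16", "group_size": group}}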
quant_ops.py:
- detect TensorCoreAWQW4A16Layout availability and stub it out for the
no-kitchen fallback (mirrors the SVDQuant W4A4 pattern)
- register the layout class + add "awq_w4a16" to QUANT_ALGOS
(storage_t = int8 packed uint4, parameters = {weight_scale, weight_zero},
default group_size = 64)
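To make those parameter names concrete, a plain-PyTorch dequantization reference for a group-wise W4A16 weight. The nibble packing order, the scale/zero shapes, and the (q - zero) * scale convention are assumptions borrowed from common AWQ practice, not read out of comfy-kitchen's gemv_awq_w4a16 kernel:

    import torch

    def awq_w4a16_dequant_ref(qweight, weight_scale, weight_zero, group_size=64, dtype=torch.bfloat16):
        # Reference only: unpack two uint4 values per int8 byte (low nibble first is
        # an assumption) and apply per-group scale/zero as w = (q - zero) * scale.
        q8 = qweight.view(torch.uint8)
        lo = (q8 & 0x0F).to(torch.int32)
        hi = (q8 >> 4).to(torch.int32)
        q = torch.stack((lo, hi), dim=-1).flatten(-2)               # (out_features, in_features)
        out_features, in_features = q.shape
        groups = in_features // group_size
        q = q.view(out_features, groups, group_size).to(torch.float32)
        scale = weight_scale.view(out_features, groups, 1).to(torch.float32)
        zero = weight_zero.view(out_features, groups, 1).to(torch.float32)
        return ((q - zero) * scale).view(out_features, in_features).to(dtype)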
ops.py: add the awq_w4a16 branch in MixedPrecisionOps.Linear._load_from_state_dict
that constructs Params(scale, zeros, group_size, ...) and wraps qweight
into a QuantizedTensor — F.linear then dispatches to ck.gemv_awq_w4a16
via the layout's aten handlers.
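A minimal sketch of that dispatch path, reusing awq_w4a16_dequant_ref from the sketch above as a stand-in for the fused ck.gemv_awq_w4a16 kernel. ToyAWQW4A16Weight is not the real QuantizedTensor/layout API, only an illustration of how a weight-side object can intercept F.linear through the __torch_function__ protocol:

    import torch
    import torch.nn.functional as F

    class ToyAWQW4A16Weight:
        # Illustrative stand-in for QuantizedTensor + TensorCoreAWQW4A16Layout: it
        # holds the packed payload plus Params-like metadata and reroutes F.linear.
        def __init__(self, qweight, scale, zero, group_size, orig_dtype):
            self.qweight, self.scale, self.zero = qweight, scale, zero
            self.group_size, self.orig_dtype = group_size, orig_dtype

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func is F.linear:
                x, weight = args[0], args[1]
                bias = args[2] if len(args) > 2 else kwargs.get("bias")
                # The real path hands qweight/scale/zero to the fused int4 GEMV;
                # here we dequantize and fall back to a plain bf16 matmul.
                w = awq_w4a16_dequant_ref(weight.qweight, weight.scale, weight.zero,
                                          weight.group_size, weight.orig_dtype)
                return F.linear(x, w, bias)
            raise NotImplementedError(f"{func} is not handled by this toy wrapper")

    # Usage: a (64, 128) weight packed into (64, 64) int8, one scale/zero per 64-wide group.
    x = torch.randn(2, 128, dtype=torch.bfloat16)
    qw = torch.randint(-128, 128, (64, 64), dtype=torch.int8)
    scale = torch.rand(64, 2, dtype=torch.bfloat16)
    zero = torch.randint(0, 16, (64, 2)).to(torch.bfloat16)
    y = F.linear(x, ToyAWQW4A16Weight(qw, scale, zero, 64, torch.bfloat16))    # -> (2, 64)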
Pairs with comfy-kitchen feat/awq-w4a16-modulation. Targets the ~10 GB
inflation in Qwen-Image-Edit kitchen-native checkpoints, where the
modulation linears (img_mod.1 / txt_mod.1) currently dominate disk + VRAM
because they're materialized as plain bf16 Linear during conversion.
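As a rough sanity check on that footprint claim (per-group parameter widths assumed, not taken from the kernel): bf16 spends 16 bits per weight, while int4 plus one fp16 scale and one fp16 zero per 64-weight group spends about 4.5 bits, so keeping the modulation linears in AWQ form shrinks them about 3.6x relative to the bf16-dequantized fallback.

    # Back-of-the-envelope bits per weight, assuming fp16 scale + fp16 zero per group of 64.
    group_size = 64
    bf16_bits = 16
    awq_bits = 4 + (16 + 16) / group_size     # = 4.5
    print(bf16_bits / awq_bits)               # about 3.6x smaller modulation weights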
parent 353978a9b7
commit 3ddcc095ed
16  comfy/ops.py

comfy/ops.py
@@ -1026,6 +1026,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 smooth_factor=smooth_factor,
                 act_unsigned=act_unsigned,
             )
+        elif self.quant_format == "awq_w4a16":
+            # AWQ W4A16: int4 weight, fp16/bf16 activation. Used for
+            # the modulation linears (img_mod.1 / txt_mod.1) so they
+            # stay int4 in checkpoint + VRAM rather than getting
+            # dequantized to bf16 at conversion time (~10 GB saving).
+            wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
+            wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
+            if wscales is None or wzeros is None:
+                raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
+            params = layout_cls.Params(
+                scale=wscales,
+                zeros=wzeros,
+                group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
+                orig_dtype=MixedPrecisionOps._compute_dtype,
+                orig_shape=(self.out_features, self.in_features),
+            )
         else:
             raise ValueError(f"Unsupported quantization format: {self.quant_format}")
quant_ops.py
@@ -40,6 +40,9 @@ except ImportError as e:
     class _CKSVDQuantW4A4Layout:
         pass
 
+    class _CKAWQW4A16Layout:
+        pass
+
     def register_layout_class(name, cls):
         pass
 
@@ -68,6 +71,16 @@ if _CK_AVAILABLE:
        class _CKSVDQuantW4A4Layout:
            pass
 
+_CK_AWQ_W4A16_AVAILABLE = False
+if _CK_AVAILABLE:
+    try:
+        from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
+        _CK_AWQ_W4A16_AVAILABLE = True
+    except ImportError:
+        logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.")
+        class _CKAWQW4A16Layout:
+            pass
+
 import comfy.float
 
 # ==============================================================================
@@ -182,6 +195,14 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
     pass
 
 
+# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen
+# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 /
+# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the
+# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback.
+class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
+    pass
+
+
 # Backward compatibility alias - default to E4M3
 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
 
@@ -198,6 +219,8 @@ if _CK_MXFP8_AVAILABLE:
     register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
 if _CK_SVDQUANT_W4A4_AVAILABLE:
     register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
+if _CK_AWQ_W4A16_AVAILABLE:
+    register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)
 
 QUANT_ALGOS = {
     "float8_e4m3fn": {
@@ -234,6 +257,14 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
         "group_size": 64,
     }
 
+if _CK_AWQ_W4A16_AVAILABLE:
+    QUANT_ALGOS["awq_w4a16"] = {
+        "storage_t": torch.int8,
+        "parameters": {"weight_scale", "weight_zero"},
+        "comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
+        "group_size": 64,
+    }
+
 
 # ==============================================================================
 # Re-exports for backward compatibility
@@ -242,6 +273,7 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
 __all__ = [
     "QuantizedTensor",
     "QuantizedLayout",
+    "TensorCoreAWQW4A16Layout",
     "TensorCoreFP8Layout",
     "TensorCoreFP8E4M3Layout",
     "TensorCoreFP8E5M2Layout",