mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 18:47:29 +08:00
Add AWQ W4A16 (modulation) integration with comfy-kitchen
Wires comfy-kitchen's TensorCoreAWQW4A16Layout (introduced on
feat/awq-w4a16-modulation) into ComfyUI's MixedPrecisionOps so checkpoints
that tag modulation linears with comfy_quant.format = "awq_w4a16" get
their (qweight, weight_scale, weight_zero) loaded into the kitchen layout
class instead of being dequantized to bf16 plain Linear at conversion time.
quant_ops.py:
- detect TensorCoreAWQW4A16Layout availability and stub it out for the
no-kitchen fallback (mirrors the SVDQuant W4A4 pattern)
- register the layout class + add "awq_w4a16" to QUANT_ALGOS
(storage_t = int8 packed uint4, parameters = {weight_scale, weight_zero},
default group_size = 64)
ops.py: add the awq_w4a16 branch in MixedPrecisionOps.Linear._load_from_state_dict
that constructs Params(scale, zeros, group_size, ...) and wraps qweight
into a QuantizedTensor — F.linear then dispatches to ck.gemv_awq_w4a16
via the layout's aten handlers.
Pairs with comfy-kitchen feat/awq-w4a16-modulation. Targets the ~10 GB
inflation in Qwen-Image-Edit kitchen-native checkpoints, where the
modulation linears (img_mod.1 / txt_mod.1) currently dominate disk + VRAM
because they're materialized as plain bf16 Linear during conversion.
This commit is contained in:
parent
353978a9b7
commit
3ddcc095ed
16
comfy/ops.py
16
comfy/ops.py
@ -1026,6 +1026,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
smooth_factor=smooth_factor,
|
smooth_factor=smooth_factor,
|
||||||
act_unsigned=act_unsigned,
|
act_unsigned=act_unsigned,
|
||||||
)
|
)
|
||||||
|
elif self.quant_format == "awq_w4a16":
|
||||||
|
# AWQ W4A16: int4 weight, fp16/bf16 activation. Used for
|
||||||
|
# the modulation linears (img_mod.1 / txt_mod.1) so they
|
||||||
|
# stay int4 in checkpoint + VRAM rather than getting
|
||||||
|
# dequantized to bf16 at conversion time (~10 GB saving).
|
||||||
|
wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||||
|
wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
|
||||||
|
if wscales is None or wzeros is None:
|
||||||
|
raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=wscales,
|
||||||
|
zeros=wzeros,
|
||||||
|
group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||||
|
|
||||||
|
|||||||
@ -40,6 +40,9 @@ except ImportError as e:
|
|||||||
class _CKSVDQuantW4A4Layout:
|
class _CKSVDQuantW4A4Layout:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class _CKAWQW4A16Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
def register_layout_class(name, cls):
|
def register_layout_class(name, cls):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,6 +71,16 @@ if _CK_AVAILABLE:
|
|||||||
class _CKSVDQuantW4A4Layout:
|
class _CKSVDQuantW4A4Layout:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
_CK_AWQ_W4A16_AVAILABLE = False
|
||||||
|
if _CK_AVAILABLE:
|
||||||
|
try:
|
||||||
|
from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
|
||||||
|
_CK_AWQ_W4A16_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.")
|
||||||
|
class _CKAWQW4A16Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -182,6 +195,14 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen
|
||||||
|
# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 /
|
||||||
|
# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the
|
||||||
|
# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback.
|
||||||
|
class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - default to E4M3
|
# Backward compatibility alias - default to E4M3
|
||||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||||
|
|
||||||
@ -198,6 +219,8 @@ if _CK_MXFP8_AVAILABLE:
|
|||||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
if _CK_SVDQUANT_W4A4_AVAILABLE:
|
if _CK_SVDQUANT_W4A4_AVAILABLE:
|
||||||
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
|
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
|
||||||
|
if _CK_AWQ_W4A16_AVAILABLE:
|
||||||
|
register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -234,6 +257,14 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
|
|||||||
"group_size": 64,
|
"group_size": 64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _CK_AWQ_W4A16_AVAILABLE:
|
||||||
|
QUANT_ALGOS["awq_w4a16"] = {
|
||||||
|
"storage_t": torch.int8,
|
||||||
|
"parameters": {"weight_scale", "weight_zero"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
|
||||||
|
"group_size": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Re-exports for backward compatibility
|
# Re-exports for backward compatibility
|
||||||
@ -242,6 +273,7 @@ if _CK_SVDQUANT_W4A4_AVAILABLE:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"QuantizedTensor",
|
"QuantizedTensor",
|
||||||
"QuantizedLayout",
|
"QuantizedLayout",
|
||||||
|
"TensorCoreAWQW4A16Layout",
|
||||||
"TensorCoreFP8Layout",
|
"TensorCoreFP8Layout",
|
||||||
"TensorCoreFP8E4M3Layout",
|
"TensorCoreFP8E4M3Layout",
|
||||||
"TensorCoreFP8E5M2Layout",
|
"TensorCoreFP8E5M2Layout",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user