Mirror of https://github.com/comfyanonymous/ComfyUI.git
Address AWQ ComfyUI review feedback

commit b6f438db65
parent 96e5287a72

 comfy/ops.py | 33
comfy/ops.py:

@@ -951,9 +951,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         if self.quant_format is None:
             raise ValueError(f"Unknown quantization format for layer {layer_name}")
 
+        if self.quant_format not in QUANT_ALGOS:
+            raise ValueError(
+                f"Quantization format '{self.quant_format}' for layer {layer_name} "
+                f"is not available in this build (supported: {sorted(QUANT_ALGOS.keys())}). "
+                "Update comfy_kitchen to enable it."
+            )
         qconfig = QUANT_ALGOS[self.quant_format]
         self.layout_type = qconfig["comfy_tensor_layout"]
-        layout_cls = get_layout_class(self.layout_type)
+        self._layout_cls = get_layout_class(self.layout_type)
+        self._layout_quantizes_input = getattr(self._layout_cls, "QUANTIZES_INPUT", True)
+        layout_cls = self._layout_cls
 
         # Load format-specific parameters
         if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
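The new guard makes an unknown format fail fast at load time with the supported list in the message, instead of a bare KeyError on the next line. For orientation, a hypothetical sketch of the lookup pattern this hunk guards; the key name comfy_tensor_layout comes from the diff, but the entry contents and the helper are illustrative, not comfy_kitchen's actual tables:

# Illustrative only: the real QUANT_ALGOS is populated from comfy_kitchen
# and its entries carry more fields; only the lookup/validation shape matters.
QUANT_ALGOS = {
    "awq_w4a16": {"comfy_tensor_layout": "TensorCoreAWQW4A16Layout"},
    "svdquant_w4a4": {"comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout"},
}

def resolve_layout_type(quant_format: str, layer_name: str) -> str:
    # Fail fast with the supported formats in the error message.
    if quant_format not in QUANT_ALGOS:
        raise ValueError(
            f"Quantization format '{quant_format}' for layer {layer_name} "
            f"is not available in this build (supported: {sorted(QUANT_ALGOS.keys())})."
        )
    return QUANT_ALGOS[quant_format]["comfy_tensor_layout"]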
@@ -1006,14 +1014,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             smooth_factor = self._load_scale_param(state_dict, prefix, "smooth_factor", device, manually_loaded_keys)
             act_unsigned = bool(layer_conf.get("act_unsigned", False))
 
-            # Early Qwen-Image conversion artifacts did not persist the
-            # fused GELU -> fc2 unsigned-activation flag. Those layers
-            # are the second linear in the feed-forward block.
-            if not act_unsigned and (
-                layer_name.endswith(".img_mlp.net.2") or layer_name.endswith(".txt_mlp.net.2")
-            ):
-                act_unsigned = True
-
             if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)):
                 raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}")
 
@@ -1027,10 +1027,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 act_unsigned=act_unsigned,
             )
         elif self.quant_format == "awq_w4a16":
-            # AWQ W4A16: int4 weight, fp16/bf16 activation. Used for
-            # the modulation linears (img_mod.1 / txt_mod.1) so they
-            # stay int4 in checkpoint + VRAM rather than getting
-            # dequantized to bf16 at conversion time (~10 GB saving).
+            # AWQ W4A16: int4 weight, fp16/bf16 activation. Used by
+            # Qwen-Image-Edit modulation linears so they stay packed
+            # instead of being dequantized to bf16 at load time.
             wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
             wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
             if wscales is None or wzeros is None:
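AWQ W4A16 stores the weight as int4 with per-group scales and zero points and dequantizes as w ≈ (q - zero) * scale, while activations stay fp16/bf16. A minimal sketch of that math, assuming an unpacked int4-in-int8 weight and a group size of 128; kitchen's actual packed tensor-core layout and its gemv kernel differ:

import torch

def dequant_awq_w4a16(qweight, scales, zeros, group_size=128):
    """w ≈ (q - zero) * scale, applied per group along the input dim.

    qweight: [out, in] int4 values stored in an int8 tensor (0..15)
    scales, zeros: [out, in // group_size]
    """
    s = scales.repeat_interleave(group_size, dim=1)  # expand one value per group to per-column
    z = zeros.repeat_interleave(group_size, dim=1)
    return (qweight.to(s.dtype) - z) * s

# W4A16 matmul: dequantize (or fuse the dequant into the kernel) and run at bf16.
x = torch.randn(4, 256, dtype=torch.bfloat16)
q = torch.randint(0, 16, (128, 256), dtype=torch.int8)
sc = torch.rand(128, 2, dtype=torch.bfloat16)
zp = torch.full((128, 2), 8, dtype=torch.bfloat16)
y = x @ dequant_awq_w4a16(q, sc, zp).T  # [4, 128], activations never quantized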
@@ -1150,13 +1149,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
         # Inference path (unchanged)
         if _use_quantized:
-            # Some layouts (e.g. SVDQuant W4A4) do activation quantization
-            # inside their fused kernel and cannot pre-quantize a float
-            # tensor up-front. Skip the input wrapping for those.
-            layout_cls = get_layout_class(self.layout_type)
-            layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True)
-
-            if layout_quantizes_input:
+            if getattr(self, "_layout_quantizes_input", True):
                 # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
                 input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
 
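This hunk and the first one cooperate: load time caches the layout class and its QUANTIZES_INPUT flag, and the hot path reads the cached attribute instead of re-resolving the layout on every forward. A self-contained sketch of the pattern, with illustrative class names; the real flag semantics (fused kernels that quantize activations internally must receive the raw float tensor) come from the removed comments above:

class ExampleLayout:
    QUANTIZES_INPUT = True   # default: pre-quantize the activation up front

class ExampleFusedW4A4Layout(ExampleLayout):
    # Fused kernel quantizes the activation internally, so the forward
    # path must skip the input wrapping and pass the float tensor through.
    QUANTIZES_INPUT = False

class ExampleLinear:
    def load(self, layout_cls):
        # Resolve once at load time, as the first hunk now does.
        self._layout_cls = layout_cls
        self._layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True)

    def forward(self, x):
        # Cheap cached-attribute check per call instead of a registry lookup.
        if getattr(self, "_layout_quantizes_input", True):
            x = f"pre-quantized({x})"  # placeholder for the real input wrapping
        return x

lin = ExampleLinear()
lin.load(ExampleFusedW4A4Layout)
print(lin.forward("act"))  # -> "act": the fused layout skips pre-quantization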
@@ -18,8 +18,14 @@ try:
     else:
         cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
         if cuda_version < (13,):
-            ck.registry.disable("cuda")
-            logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
+            # cu<13 lacks the block-scale FP4 cuBLASLt APIs but not the int4
+            # MMA or fp8 paths. Kitchen's per-op FunctionConstraints already
+            # gate scaled_mm_nvfp4 behind HAS_CUBLASLT, so we keep the CUDA
+            # backend enabled for svdquant_w4a4 / fp8 / mxfp8 / rope.
+            logging.warning(
+                "cuda_version=%s < 13: NVFP4 cuBLAS path unavailable; "
+                "other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.",
+                ".".join(map(str, cuda_version)))
 
     ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
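The cuda_version < (13,) test works because Python compares tuples element-wise, left to right, with a shorter tuple that is a prefix of a longer one sorting first. So a torch build against cu12.4 parses and compares like this:

>>> tuple(map(int, "12.4".split('.')))
(12, 4)
>>> (12, 4) < (13,)    # first elements differ: 12 < 13
True
>>> (13, 0) < (13,)    # equal prefix: the shorter tuple sorts first
False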
@@ -68,8 +74,10 @@ if _CK_AVAILABLE:
         _CK_SVDQUANT_W4A4_AVAILABLE = True
     except ImportError:
         logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.")
-        class _CKSVDQuantW4A4Layout:
-            pass
+
+if not _CK_SVDQUANT_W4A4_AVAILABLE:
+    class _CKSVDQuantW4A4Layout:
+        pass
 
 _CK_AWQ_W4A16_AVAILABLE = False
 if _CK_AVAILABLE:
@@ -77,9 +85,11 @@ if _CK_AVAILABLE:
         from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
         _CK_AWQ_W4A16_AVAILABLE = True
     except ImportError:
-        logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.")
-        class _CKAWQW4A16Layout:
-            pass
+        logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will not be supported.")
+
+if not _CK_AWQ_W4A16_AVAILABLE:
+    class _CKAWQW4A16Layout:
+        pass
 
 import comfy.float
 
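The stub relocation in these two hunks closes a real gap: the fallback class used to be defined inside except ImportError:, which only runs when the import is actually attempted, so a build where _CK_AVAILABLE is False would leave the name undefined and the later class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout) statement would raise NameError. A minimal reproduction of the fixed shape, with kitchen stubbed out as unavailable:

_CK_AVAILABLE = False        # pretend comfy_kitchen failed to import entirely

_CK_AWQ_W4A16_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
        _CK_AWQ_W4A16_AVAILABLE = True
    except ImportError:
        pass  # logged in the real module; no class definition hidden in here anymore

# Module-level guard: the fallback now exists on every path, including the
# one where the try block above never ran at all.
if not _CK_AWQ_W4A16_AVAILABLE:
    class _CKAWQW4A16Layout:
        pass

class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):  # no NameError now
    pass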
@@ -195,10 +205,8 @@ class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
     pass
 
 
-# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen
-# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 /
-# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the
-# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback.
+# AWQ W4A16 — pre-quantized offline modulation linears. Kitchen owns the
+# tensor subclass dispatch and gemv implementation; ComfyUI only loads params.
 class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
     pass
 
@@ -273,12 +281,12 @@ if _CK_AWQ_W4A16_AVAILABLE:
 __all__ = [
     "QuantizedTensor",
     "QuantizedLayout",
-    "TensorCoreAWQW4A16Layout",
     "TensorCoreFP8Layout",
     "TensorCoreFP8E4M3Layout",
     "TensorCoreFP8E5M2Layout",
     "TensorCoreNVFP4Layout",
     "TensorCoreSVDQuantW4A4Layout",
+    "TensorCoreAWQW4A16Layout",
     "QUANT_ALGOS",
     "register_layout_op",
 ]