This commit is contained in:
HK416-TYPED 2026-06-22 19:13:16 +03:30 committed by GitHub
commit a15f252b29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 128 additions and 14 deletions

View File

@ -1071,9 +1071,20 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if module.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
if module.quant_format not in QUANT_ALGOS:
raise ValueError(
f"Quantization format '{module.quant_format}' for layer {layer_name} "
f"is not available in this build (supported: {sorted(QUANT_ALGOS.keys())}). "
"Update comfy_kitchen to enable it."
)
qconfig = QUANT_ALGOS[module.quant_format]
module.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(module.layout_type)
module._layout_cls = layout_cls
# W4A16-style layouts keep the activation in compute dtype; the forward
# path reads this to decide whether to quantize the input.
module._layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True)
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
@ -1089,6 +1100,35 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
elif module.quant_format == "svdquant_w4a4":
# SVDQuant W4A4: per-group weight scales + low-rank correction
# (proj_down, proj_up) + activation smoothing (smooth_factor).
wscales = pop_scale("weight_scale")
proj_down = pop_scale("proj_down")
proj_up = pop_scale("proj_up")
smooth_factor = pop_scale("smooth_factor")
if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)):
raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}")
scales = {
"scale": wscales,
"proj_down": proj_down,
"proj_up": proj_up,
"smooth_factor": smooth_factor,
"act_unsigned": bool(layer_conf.get("act_unsigned", False)),
}
elif module.quant_format == "awq_w4a16":
# AWQ W4A16: int4 weight, fp16/bf16 activation. Used by
# Qwen-Image-Edit modulation linears so they stay packed instead of
# being dequantized to bf16 at load time.
wscales = pop_scale("weight_scale")
wzeros = pop_scale("weight_zero")
if wscales is None or wzeros is None:
raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
scales = {
"scale": wscales,
"zeros": wzeros,
"group_size": int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
}
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
@ -1178,7 +1218,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
# Preserve the SVDQuant W4A4 act_unsigned flag on round-trip save.
_params = getattr(getattr(self, 'weight', None), '_params', None)
extra_quant_conf = {"act_unsigned": True} if getattr(_params, 'act_unsigned', False) else None
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf=extra_quant_conf, extra_quant_params=("input_scale",))
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -1228,18 +1271,18 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Inference path (unchanged)
if _use_quantized:
if getattr(self, "_layout_quantizes_input", True):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))

View File

@ -20,8 +20,14 @@ try:
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
# cu<13 lacks the block-scale FP4 cuBLASLt APIs but not the int4
# MMA or fp8 paths. Kitchen's per-op FunctionConstraints already
# gate scaled_mm_nvfp4 behind HAS_CUBLASLT, so we keep the CUDA
# backend enabled for svdquant_w4a4 / fp8 / mxfp8 / rope.
logging.warning(
"cuda_version=%s < 13: NVFP4 cuBLAS path unavailable; "
"other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.",
".".join(map(str, cuda_version)))
if args.enable_triton_backend:
try:
@ -47,6 +53,12 @@ except ImportError as e:
class _CKNvfp4Layout:
pass
class _CKSVDQuantW4A4Layout:
pass
class _CKAWQW4A16Layout:
pass
def register_layout_class(name, cls):
pass
@ -65,6 +77,30 @@ if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
_CK_SVDQUANT_W4A4_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout
_CK_SVDQUANT_W4A4_AVAILABLE = True
except ImportError:
logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.")
if not _CK_SVDQUANT_W4A4_AVAILABLE:
class _CKSVDQuantW4A4Layout:
pass
_CK_AWQ_W4A16_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
_CK_AWQ_W4A16_AVAILABLE = True
except ImportError:
logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will not be supported.")
if not _CK_AWQ_W4A16_AVAILABLE:
class _CKAWQW4A16Layout:
pass
import comfy.float
# ==============================================================================
@ -172,6 +208,19 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e5m2
# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the
# kitchen-registered layout class unchanged. Comfy-side extension reserved in
# case per-layer input scales or other Comfy-specific metadata are added later.
class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
pass
# AWQ W4A16 — pre-quantized offline modulation linears. Kitchen owns the
# tensor subclass dispatch and gemv implementation; ComfyUI only loads params.
class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
pass
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
@ -186,6 +235,10 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
if _CK_SVDQUANT_W4A4_AVAILABLE:
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
if _CK_AWQ_W4A16_AVAILABLE:
register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@ -214,6 +267,22 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32,
}
if _CK_SVDQUANT_W4A4_AVAILABLE:
QUANT_ALGOS["svdquant_w4a4"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"},
"comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout",
"group_size": 64,
}
if _CK_AWQ_W4A16_AVAILABLE:
QUANT_ALGOS["awq_w4a16"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale", "weight_zero"},
"comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
"group_size": 64,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -226,6 +295,8 @@ __all__ = [
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"TensorCoreSVDQuantW4A4Layout",
"TensorCoreAWQW4A16Layout",
"QUANT_ALGOS",
"register_layout_op",
]