Add SVDQuant W4A4 integration with comfy-kitchen (kitchen-native row-major)

quant_ops.py: register TensorCoreSVDQuantW4A4Layout when comfy-kitchen exposes
it; gate the kitchen CUDA backend on cuda >= 13 (the optimized kitchen CUDA
ops are validated against cu13+ runtimes; on older cu the backend falls back
to eager).

ops.py: handle svdquant_w4a4 quant_format by loading weight_scale / proj_down /
proj_up / smooth_factor into TensorCoreSVDQuantW4A4Layout.Params, with the
img_mlp.net.2 / txt_mlp.net.2 fallback for act_unsigned. Targets the row-major
kitchen-native kernels on feat/svdquant-w4a4-kitchen-native; the verbatim
zgemm path is a sibling branch.
This commit is contained in:
lax 2026-04-20 06:55:48 +00:00
parent 3086026401
commit 353978a9b7
2 changed files with 78 additions and 10 deletions

View File

@ -997,6 +997,35 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "svdquant_w4a4":
# SVDQuant W4A4: per-group weight scales + low-rank correction
# (proj_down, proj_up) + activation smoothing (smooth_factor)
wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
proj_down = self._load_scale_param(state_dict, prefix, "proj_down", device, manually_loaded_keys)
proj_up = self._load_scale_param(state_dict, prefix, "proj_up", device, manually_loaded_keys)
smooth_factor = self._load_scale_param(state_dict, prefix, "smooth_factor", device, manually_loaded_keys)
act_unsigned = bool(layer_conf.get("act_unsigned", False))
# Early Qwen-Image conversion artifacts did not persist the
# fused GELU -> fc2 unsigned-activation flag. Those layers
# are the second linear in the feed-forward block.
if not act_unsigned and (
layer_name.endswith(".img_mlp.net.2") or layer_name.endswith(".txt_mlp.net.2")
):
act_unsigned = True
if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)):
raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}")
params = layout_cls.Params(
scale=wscales,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
proj_down=proj_down,
proj_up=proj_up,
smooth_factor=smooth_factor,
act_unsigned=act_unsigned,
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
@ -1046,6 +1075,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
if bool(getattr(getattr(self.weight, "_params", None), "act_unsigned", False)):
quant_conf["act_unsigned"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
@ -1103,18 +1134,24 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Inference path (unchanged)
if _use_quantized:
# Some layouts (e.g. SVDQuant W4A4) do activation quantization
# inside their fused kernel and cannot pre-quantize a float
# tensor up-front. Skip the input wrapping for those.
layout_cls = get_layout_class(self.layout_type)
layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True)
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
if layout_quantizes_input:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))

View File

@ -37,6 +37,9 @@ except ImportError as e:
class _CKNvfp4Layout:
pass
class _CKSVDQuantW4A4Layout:
pass
def register_layout_class(name, cls):
pass
@ -55,6 +58,16 @@ if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
_CK_SVDQUANT_W4A4_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout
_CK_SVDQUANT_W4A4_AVAILABLE = True
except ImportError:
logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.")
class _CKSVDQuantW4A4Layout:
pass
import comfy.float
# ==============================================================================
@ -162,6 +175,13 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e5m2
# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the
# kitchen-registered layout class unchanged. Comfy-side extension reserved in
# case per-layer input scales or other Comfy-specific metadata are added later.
class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
pass
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
@ -176,6 +196,8 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
if _CK_SVDQUANT_W4A4_AVAILABLE:
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@ -204,6 +226,14 @@ if _CK_MXFP8_AVAILABLE:
"group_size": 32,
}
if _CK_SVDQUANT_W4A4_AVAILABLE:
QUANT_ALGOS["svdquant_w4a4"] = {
"storage_t": torch.int8,
"parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"},
"comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout",
"group_size": 64,
}
# ==============================================================================
# Re-exports for backward compatibility
@ -216,6 +246,7 @@ __all__ = [
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"TensorCoreSVDQuantW4A4Layout",
"QUANT_ALGOS",
"register_layout_op",
]