mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge 2322ff5bf7 into 7bbf1e8169
This commit is contained in:
commit
1e791d6557
70
comfy/ops.py
70
comfy/ops.py
@ -1063,9 +1063,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
if self.quant_format not in QUANT_ALGOS:
|
||||
raise ValueError(
|
||||
f"Quantization format '{self.quant_format}' for layer {layer_name} "
|
||||
f"is not available in this build (supported: {sorted(QUANT_ALGOS.keys())}). "
|
||||
"Update comfy_kitchen to enable it."
|
||||
)
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
self._layout_cls = get_layout_class(self.layout_type)
|
||||
self._layout_quantizes_input = getattr(self._layout_cls, "QUANTIZES_INPUT", True)
|
||||
layout_cls = self._layout_cls
|
||||
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
@ -1109,6 +1117,42 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
elif self.quant_format == "svdquant_w4a4":
|
||||
# SVDQuant W4A4: per-group weight scales + low-rank correction
|
||||
# (proj_down, proj_up) + activation smoothing (smooth_factor)
|
||||
wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
proj_down = self._load_scale_param(state_dict, prefix, "proj_down", device, manually_loaded_keys)
|
||||
proj_up = self._load_scale_param(state_dict, prefix, "proj_up", device, manually_loaded_keys)
|
||||
smooth_factor = self._load_scale_param(state_dict, prefix, "smooth_factor", device, manually_loaded_keys)
|
||||
act_unsigned = bool(layer_conf.get("act_unsigned", False))
|
||||
|
||||
if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)):
|
||||
raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=wscales,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
proj_down=proj_down,
|
||||
proj_up=proj_up,
|
||||
smooth_factor=smooth_factor,
|
||||
act_unsigned=act_unsigned,
|
||||
)
|
||||
elif self.quant_format == "awq_w4a16":
|
||||
# AWQ W4A16: int4 weight, fp16/bf16 activation. Used by
|
||||
# Qwen-Image-Edit modulation linears so they stay packed
|
||||
# instead of being dequantized to bf16 at load time.
|
||||
wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
|
||||
if wscales is None or wzeros is None:
|
||||
raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
|
||||
params = layout_cls.Params(
|
||||
scale=wscales,
|
||||
zeros=wzeros,
|
||||
group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
@ -1158,6 +1202,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
if bool(getattr(getattr(self.weight, "_params", None), "act_unsigned", False)):
|
||||
quant_conf["act_unsigned"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
@ -1215,18 +1261,18 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
if getattr(self, "_layout_quantizes_input", True):
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
|
||||
|
||||
@ -20,8 +20,14 @@ try:
|
||||
else:
|
||||
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
# cu<13 lacks the block-scale FP4 cuBLASLt APIs but not the int4
|
||||
# MMA or fp8 paths. Kitchen's per-op FunctionConstraints already
|
||||
# gate scaled_mm_nvfp4 behind HAS_CUBLASLT, so we keep the CUDA
|
||||
# backend enabled for svdquant_w4a4 / fp8 / mxfp8 / rope.
|
||||
logging.warning(
|
||||
"cuda_version=%s < 13: NVFP4 cuBLAS path unavailable; "
|
||||
"other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.",
|
||||
".".join(map(str, cuda_version)))
|
||||
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
@ -47,6 +53,12 @@ except ImportError as e:
|
||||
class _CKNvfp4Layout:
|
||||
pass
|
||||
|
||||
class _CKSVDQuantW4A4Layout:
|
||||
pass
|
||||
|
||||
class _CKAWQW4A16Layout:
|
||||
pass
|
||||
|
||||
def register_layout_class(name, cls):
|
||||
pass
|
||||
|
||||
@ -65,6 +77,30 @@ if not _CK_MXFP8_AVAILABLE:
|
||||
class _CKMxfp8Layout:
|
||||
pass
|
||||
|
||||
_CK_SVDQUANT_W4A4_AVAILABLE = False
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout
|
||||
_CK_SVDQUANT_W4A4_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.")
|
||||
|
||||
if not _CK_SVDQUANT_W4A4_AVAILABLE:
|
||||
class _CKSVDQuantW4A4Layout:
|
||||
pass
|
||||
|
||||
_CK_AWQ_W4A16_AVAILABLE = False
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
|
||||
_CK_AWQ_W4A16_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will not be supported.")
|
||||
|
||||
if not _CK_AWQ_W4A16_AVAILABLE:
|
||||
class _CKAWQW4A16Layout:
|
||||
pass
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
@ -172,6 +208,19 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
|
||||
FP8_DTYPE = torch.float8_e5m2
|
||||
|
||||
|
||||
# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the
|
||||
# kitchen-registered layout class unchanged. Comfy-side extension reserved in
|
||||
# case per-layer input scales or other Comfy-specific metadata are added later.
|
||||
class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
|
||||
pass
|
||||
|
||||
|
||||
# AWQ W4A16 — pre-quantized offline modulation linears. Kitchen owns the
|
||||
# tensor subclass dispatch and gemv implementation; ComfyUI only loads params.
|
||||
class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
|
||||
pass
|
||||
|
||||
|
||||
# Backward compatibility alias - default to E4M3
|
||||
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
|
||||
|
||||
@ -186,6 +235,10 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
if _CK_SVDQUANT_W4A4_AVAILABLE:
|
||||
register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
|
||||
if _CK_AWQ_W4A16_AVAILABLE:
|
||||
register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
@ -214,6 +267,22 @@ if _CK_MXFP8_AVAILABLE:
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
if _CK_SVDQUANT_W4A4_AVAILABLE:
|
||||
QUANT_ALGOS["svdquant_w4a4"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"},
|
||||
"comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout",
|
||||
"group_size": 64,
|
||||
}
|
||||
|
||||
if _CK_AWQ_W4A16_AVAILABLE:
|
||||
QUANT_ALGOS["awq_w4a16"] = {
|
||||
"storage_t": torch.int8,
|
||||
"parameters": {"weight_scale", "weight_zero"},
|
||||
"comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
|
||||
"group_size": 64,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
@ -226,6 +295,8 @@ __all__ = [
|
||||
"TensorCoreFP8E4M3Layout",
|
||||
"TensorCoreFP8E5M2Layout",
|
||||
"TensorCoreNVFP4Layout",
|
||||
"TensorCoreSVDQuantW4A4Layout",
|
||||
"TensorCoreAWQW4A16Layout",
|
||||
"QUANT_ALGOS",
|
||||
"register_layout_op",
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user