Merge 96e5287a72 into 65045730a6
Commit a4cb9bfc9d

comfy/ops.py
@@ -1109,6 +1109,51 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                    orig_dtype=MixedPrecisionOps._compute_dtype,
                    orig_shape=(self.out_features, self.in_features),
                )
            elif self.quant_format == "svdquant_w4a4":
                # SVDQuant W4A4: per-group weight scales + low-rank correction
                # (proj_down, proj_up) + activation smoothing (smooth_factor)
                wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
                proj_down = self._load_scale_param(state_dict, prefix, "proj_down", device, manually_loaded_keys)
                proj_up = self._load_scale_param(state_dict, prefix, "proj_up", device, manually_loaded_keys)
                smooth_factor = self._load_scale_param(state_dict, prefix, "smooth_factor", device, manually_loaded_keys)
                act_unsigned = bool(layer_conf.get("act_unsigned", False))

                # Early Qwen-Image conversion artifacts did not persist the
                # fused GELU -> fc2 unsigned-activation flag. Those layers
                # are the second linear in the feed-forward block.
                if not act_unsigned and (
                    layer_name.endswith(".img_mlp.net.2") or layer_name.endswith(".txt_mlp.net.2")
                ):
                    act_unsigned = True

                if any(t is None for t in (wscales, proj_down, proj_up, smooth_factor)):
                    raise ValueError(f"Missing SVDQuant W4A4 parameters for layer {layer_name}")

                params = layout_cls.Params(
                    scale=wscales,
                    orig_dtype=MixedPrecisionOps._compute_dtype,
                    orig_shape=(self.out_features, self.in_features),
                    proj_down=proj_down,
                    proj_up=proj_up,
                    smooth_factor=smooth_factor,
                    act_unsigned=act_unsigned,
                )
            elif self.quant_format == "awq_w4a16":
                # AWQ W4A16: int4 weight, fp16/bf16 activation. Used for
                # the modulation linears (img_mod.1 / txt_mod.1) so they
                # stay int4 in checkpoint + VRAM rather than getting
                # dequantized to bf16 at conversion time (~10 GB saving).
                wscales = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
                wzeros = self._load_scale_param(state_dict, prefix, "weight_zero", device, manually_loaded_keys)
                if wscales is None or wzeros is None:
                    raise ValueError(f"Missing AWQ W4A16 parameters for layer {layer_name}")
                params = layout_cls.Params(
                    scale=wscales,
                    zeros=wzeros,
                    group_size=int(layer_conf.get("group_size", qconfig.get("group_size", 64))),
                    orig_dtype=MixedPrecisionOps._compute_dtype,
                    orig_shape=(self.out_features, self.in_features),
                )
            else:
                raise ValueError(f"Unsupported quantization format: {self.quant_format}")
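For orientation, a minimal float-math sketch of the decomposition these SVDQuant parameters describe. This is not the fused comfy_kitchen W4A4 kernel; the tensor shapes, the smoothing convention, and the function name are assumptions made for illustration only.

import torch

def svdquant_reference(x: torch.Tensor,           # (tokens, in_features) activations
                       w_residual: torch.Tensor,  # (out, in): dequantized int4 residual * weight_scale
                       proj_down: torch.Tensor,   # (rank, in) low-rank factor
                       proj_up: torch.Tensor,     # (out, rank) low-rank factor
                       smooth_factor: torch.Tensor) -> torch.Tensor:  # (in,) activation smoothing
    # Smoothing migrates activation outliers into the higher-precision factors.
    x_s = x / smooth_factor
    # The 16-bit low-rank branch absorbs most of the quantization error ...
    low_rank = x_s @ proj_down.t() @ proj_up.t()
    # ... while the residual is what actually runs as the 4-bit x 4-bit matmul.
    residual = x_s @ w_residual.t()
    return low_rank + residual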
@@ -1158,6 +1203,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
        quant_conf = {"format": self.quant_format}
        if self._full_precision_mm_config:
            quant_conf["full_precision_matrix_mult"] = True
        if bool(getattr(getattr(self.weight, "_params", None), "act_unsigned", False)):
            quant_conf["act_unsigned"] = True
        sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)

        input_scale = getattr(self, 'input_scale', None)
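The `comfy_quant` key written above is the per-layer config serialized as UTF-8 JSON bytes inside a uint8 tensor, so it can round-trip through a tensor-only state dict. A small sketch of reading it back when inspecting a checkpoint; the helper name is illustrative, not part of the diff:

import json

def read_comfy_quant(sd: dict, prefix: str) -> dict:
    # Inverse of the serialization above: uint8 tensor -> bytes -> JSON -> dict.
    raw = sd["{}comfy_quant".format(prefix)]
    return json.loads(bytes(raw.tolist()).decode("utf-8"))

# e.g. {"format": "svdquant_w4a4", "act_unsigned": True}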
@@ -1215,18 +1262,24 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
        # Inference path (unchanged)
        if _use_quantized:
            # Some layouts (e.g. SVDQuant W4A4) do activation quantization
            # inside their fused kernel and cannot pre-quantize a float
            # tensor up-front. Skip the input wrapping for those.
            layout_cls = get_layout_class(self.layout_type)
            layout_quantizes_input = getattr(layout_cls, "QUANTIZES_INPUT", True)

            if layout_quantizes_input:
                # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
                input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
                # Fall back to non-quantized for non-2D tensors
                if input_reshaped.ndim == 2:
                    reshaped_3d = input.ndim == 3
                    # dtype is now implicit in the layout class
                    scale = getattr(self, 'input_scale', None)
                    if scale is not None:
                        scale = comfy.model_management.cast_to_device(scale, input.device, None)
                    input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

            output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
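The gate above is read with getattr(..., True), so existing layouts keep the old pre-quantize behaviour and only layouts that opt out skip the input wrapping. A hedged sketch of that convention; the layout class below is purely illustrative and not defined anywhere in this diff:

class ExampleFusedLayout:
    # A layout whose kernel quantizes activations internally (as the comment
    # above describes for SVDQuant W4A4) would advertise it like this.
    QUANTIZES_INPUT = False

def wants_host_side_input_quant(layout_cls) -> bool:
    # Defaulting to True preserves the previous behaviour for layouts that
    # never declare the attribute (FP8, NVFP4, MXFP8, ...).
    return getattr(layout_cls, "QUANTIZES_INPUT", True)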
@@ -47,6 +47,12 @@ except ImportError as e:
    class _CKNvfp4Layout:
        pass

    class _CKSVDQuantW4A4Layout:
        pass

    class _CKAWQW4A16Layout:
        pass

    def register_layout_class(name, cls):
        pass
@@ -65,6 +71,26 @@ if not _CK_MXFP8_AVAILABLE:
    class _CKMxfp8Layout:
        pass

_CK_SVDQUANT_W4A4_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreSVDQuantW4A4Layout as _CKSVDQuantW4A4Layout
        _CK_SVDQUANT_W4A4_AVAILABLE = True
    except ImportError:
        logging.info("comfy_kitchen does not expose SVDQuant W4A4 layout; int4 SVDQuant checkpoints will not be supported.")
        class _CKSVDQuantW4A4Layout:
            pass

_CK_AWQ_W4A16_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreAWQW4A16Layout as _CKAWQW4A16Layout
        _CK_AWQ_W4A16_AVAILABLE = True
    except ImportError:
        logging.info("comfy_kitchen does not expose AWQ W4A16 layout; int4 AWQ modulation checkpoints will fall back to bf16-dequantized layers.")
        class _CKAWQW4A16Layout:
            pass

import comfy.float

# ==============================================================================
@@ -172,6 +198,21 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
    FP8_DTYPE = torch.float8_e5m2


# SVDQuant W4A4 — pre-quantized offline (no runtime quantize), pass through the
# kitchen-registered layout class unchanged. Comfy-side extension reserved in
# case per-layer input scales or other Comfy-specific metadata are added later.
class TensorCoreSVDQuantW4A4Layout(_CKSVDQuantW4A4Layout):
    pass


# AWQ W4A16 — pre-quantized offline (no runtime quantize) via the kitchen
# eager `gemv_awq_w4a16` op. Used for modulation linears (img_mod.1 /
# txt_mod.1) on Qwen-Image-Edit and similar topologies where keeping the
# weight at int4 saves ~10 GB of VRAM vs the bf16-dequantized fallback.
class TensorCoreAWQW4A16Layout(_CKAWQW4A16Layout):
    pass


# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
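As a reading aid for the `weight_scale` / `weight_zero` parameters this layout consumes, here is the standard AWQ-style per-group int4 dequantization written in plain PyTorch. It is only a reference for the arithmetic, not the `gemv_awq_w4a16` kernel; the unpacked storage format, shapes, and group axis are assumptions.

import torch

def awq_w4a16_dequant(qweight: torch.Tensor,   # (out, in) int4 values stored one per element
                      scales: torch.Tensor,    # (out, in // group_size) per-group scales
                      zeros: torch.Tensor,     # (out, in // group_size) per-group zero points
                      group_size: int = 64) -> torch.Tensor:
    out_features, in_features = qweight.shape
    q = qweight.to(torch.float32).reshape(out_features, in_features // group_size, group_size)
    # Classic asymmetric int4 dequant: (q - zero) * scale, applied per group.
    w = (q - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)
    return w.reshape(out_features, in_features)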
@@ -186,6 +227,10 @@ register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
    register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
if _CK_SVDQUANT_W4A4_AVAILABLE:
    register_layout_class("TensorCoreSVDQuantW4A4Layout", TensorCoreSVDQuantW4A4Layout)
if _CK_AWQ_W4A16_AVAILABLE:
    register_layout_class("TensorCoreAWQW4A16Layout", TensorCoreAWQW4A16Layout)

QUANT_ALGOS = {
    "float8_e4m3fn": {
@@ -214,6 +259,22 @@ if _CK_MXFP8_AVAILABLE:
        "group_size": 32,
    }

if _CK_SVDQUANT_W4A4_AVAILABLE:
    QUANT_ALGOS["svdquant_w4a4"] = {
        "storage_t": torch.int8,
        "parameters": {"weight_scale", "proj_down", "proj_up", "smooth_factor"},
        "comfy_tensor_layout": "TensorCoreSVDQuantW4A4Layout",
        "group_size": 64,
    }

if _CK_AWQ_W4A16_AVAILABLE:
    QUANT_ALGOS["awq_w4a16"] = {
        "storage_t": torch.int8,
        "parameters": {"weight_scale", "weight_zero"},
        "comfy_tensor_layout": "TensorCoreAWQW4A16Layout",
        "group_size": 64,
    }


# ==============================================================================
# Re-exports for backward compatibility
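Given the registry entries above, a loader can sanity-check a quantized layer before constructing its layout. A brief sketch, assuming it runs in the same module as QUANT_ALGOS and that per-layer parameters are stored under flat "{prefix}{param}" keys as in the loading code earlier in this diff; the helper and the example prefix are illustrative:

def missing_layer_params(state_dict: dict, prefix: str, fmt: str) -> set:
    # Returns the parameter names listed in QUANT_ALGOS[fmt] that are absent
    # from the state dict for this layer (an empty set means the layer is complete).
    algo = QUANT_ALGOS[fmt]
    return {p for p in algo["parameters"] if "{}{}".format(prefix, p) not in state_dict}

# e.g. missing_layer_params(sd, "blocks.0.img_mod.1.", "awq_w4a16") -> set()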
@@ -222,10 +283,12 @@ if _CK_MXFP8_AVAILABLE:
__all__ = [
    "QuantizedTensor",
    "QuantizedLayout",
    "TensorCoreAWQW4A16Layout",
    "TensorCoreFP8Layout",
    "TensorCoreFP8E4M3Layout",
    "TensorCoreFP8E5M2Layout",
    "TensorCoreNVFP4Layout",
    "TensorCoreSVDQuantW4A4Layout",
    "QUANT_ALGOS",
    "register_layout_op",
]