add svdquant int4 support, modify qwen model to support nunchaku style merged qkv

This commit is contained in:
Yu Li 2025-11-29 12:01:17 -06:00
parent a17cf1c387
commit c8794e1155
8 changed files with 1721 additions and 67 deletions

View File

@ -124,6 +124,22 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
| svdquant_int4 | int8 (packed 4-bit) | - | - | - | - |
| svdquant_nvfp4 | int8 (packed 4-bit) | - | - | - | - |
| awq_int4 | int32 (packed 4-bit) | - | - | - | - |
For SVDQuant formats, additional parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **smooth_factor**: Smoothing factors for inputs (shape: in_features)
- **smooth_factor_orig**: Original smoothing factors (shape: in_features)
- **proj_down**: Low-rank down projection (shape: in_features, rank)
- **proj_up**: Low-rank up projection (shape: out_features, rank)
- **wtscale**: Global weight scale (nvfp4 only, scalar float)
- **wcscales**: Channel-wise weight scales (nvfp4 only, shape: out_features)
For AWQ format, the following parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **wzeros**: Weight zero points (shape: in_features // group_size, out_features)
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
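As a quick orientation, each format's configuration can be looked up in that registry. A minimal sketch (key names follow the `QUANT_ALGOS` entries added later in this commit; the output comments are illustrative):
```python
from comfy.quant_ops import QUANT_ALGOS

cfg = QUANT_ALGOS["svdquant_int4"]
print(cfg["comfy_tensor_layout"])  # SVDQuantLayout
print(cfg["group_size"])           # 64
print(sorted(cfg["parameters"]))   # ['proj_down', 'proj_up', 'smooth_factor', 'smooth_factor_orig', 'wscales']
```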
@ -139,9 +155,9 @@ Example:
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
"model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
"model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
"model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
}
}
}
@ -165,4 +181,137 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
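A minimal sketch of steps 3 and 4, assuming per-tensor absolute-max calibration for `float8_e4m3fn` (the layer name, collected statistics, and file name are hypothetical; the exact statistic depends on your calibration recipe):
```python
import torch
from safetensors.torch import save_file

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

# Hypothetical per-layer max(|activation|) values collected during calibration
collected_amax = {"model.layers.0.mlp.up_proj": torch.tensor(12.5)}

extra_tensors = {}
for layer_name, amax in collected_amax.items():
    # Step 3: derive input_scale so the observed range maps onto the FP8 dynamic range
    extra_tensors[f"{layer_name}.input_scale"] = (amax / FP8_E4M3_MAX).to(torch.float32)

# Step 4: store the scales alongside the quantized weights (merge into the real checkpoint)
save_file(extra_tensors, "input_scales_only.safetensors")
```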
## SVDQuant
SVDQuant is an advanced 4-bit quantization scheme that decomposes linear operations using low-rank factorization combined with residual quantization:
```
X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
```
Where:
- `proj_down`, `proj_up`: Low-rank factorization matrices of the original weights
- `R`: Residual weights (quantized to 4-bit)
- `quantize()`: 4-bit quantization with smoothing factors
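The decomposition is easy to verify in plain PyTorch. A minimal sketch with hypothetical shapes and without the 4-bit `quantize()` step (which the real kernels fuse with smoothing and the GEMM):
```python
import torch

in_features, out_features, rank = 64, 128, 8
X = torch.randn(4, in_features)
W = torch.randn(in_features, out_features)

# Rank-r factorization of W: proj_down is (in_features, rank), proj_up is (out_features, rank)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
proj_down = U[:, :rank] * S[:rank]
proj_up = Vh[:rank, :].T

# Residual that would be quantized to 4 bits in the real scheme
R = W - proj_down @ proj_up.T

# X*W == X * proj_down * proj_up + X * R (quantize() omitted here)
out = X @ proj_down @ proj_up.T + X @ R
assert torch.allclose(out, X @ W, atol=1e-3)
```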
### Key Features
1. **Asymmetric Quantization**: Unlike FP8, where weights and activations share the same scale-based quantization scheme (with scales prepared offline), SVDQuant:
- Quantizes weights offline with multiple parameters stored in the checkpoint
- Quantizes activations on-the-fly during forward pass using smoothing factors
2. **Two Precision Modes**:
- `svdquant_int4`: 4-bit integer quantization with group_size=64
- `svdquant_nvfp4`: 4-bit floating-point (NVIDIA FP4) with group_size=16, includes additional channel-wise scales
3. **Low-Rank Optimization**: Separates the easy-to-approximate low-rank component from the hard-to-quantize residual, improving accuracy.
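The smoothing step relies on the standard SmoothQuant-style identity, sketched below in plain PyTorch. Whether the factor is divided into the activation or multiplied into the weight, and the actual factor values, are details of the nunchaku kernels; this only illustrates why the product is preserved:
```python
import torch

x = torch.randn(4, 64)                # activations, (batch, in_features)
w = torch.randn(128, 64)              # weights, (out_features, in_features)
smooth_factor = torch.rand(64) + 0.5  # per-input-channel smoothing factors

x_smoothed = x / smooth_factor        # applied to activations on the fly
w_smoothed = w * smooth_factor        # folded into the stored weights offline

# The product is unchanged, but x_smoothed has a flatter per-channel range,
# which makes 4-bit activation quantization easier.
assert torch.allclose(x_smoothed @ w_smoothed.T, x @ w.T, atol=1e-4)
```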
### Implementation
SVDQuant requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```
The implementation uses two main operations:
- `svdq_quantize_w4a4_act_fuse_lora_cuda`: Quantizes activations and computes low-rank hidden states
- `svdq_gemm_w4a4_cuda`: Performs the quantized GEMM with low-rank residual addition
### Checkpoint Format
SVDQuant checkpoints contain the standard weight tensor (packed 4-bit residuals in int8) plus additional parameters per quantized layer:
```python
{
"layer_name.weight": tensor, # Packed 4-bit residual weights (out_features, in_features // 2)
"layer_name.wscales": tensor, # Weight scales (in_features // group_size, out_features)
"layer_name.smooth_factor": tensor, # Smoothing factors (in_features,)
"layer_name.smooth_factor_orig": tensor, # Original smoothing factors (in_features,)
"layer_name.proj_down": tensor, # Low-rank down projection (in_features, rank)
"layer_name.proj_up": tensor, # Low-rank up projection (out_features, rank)
# For nvfp4 only:
"layer_name.wtscale": float, # Global weight scale
"layer_name.wcscales": tensor, # Channel-wise scales (out_features,)
}
```
The quantization metadata specifies which layers use SVDQuant:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": {"format": "svdquant_int4"},
"model.layers.0.mlp.down_proj": {"format": "svdquant_int4"}
}
}
}
```
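A converted checkpoint can be inspected directly. The file name below is hypothetical, while the metadata key and per-layer tensor names match what the converter described later in this commit writes:
```python
import json
from safetensors import safe_open

# Hypothetical path to a checkpoint produced by comfy/svdquant_converter.py
with safe_open("qwen_image_svdq_int4_comfy.safetensors", framework="pt", device="cpu") as f:
    meta = json.loads(f.metadata()["_quantization_metadata"])
    print(meta["format_version"], "-", len(meta["layers"]), "quantized layers")

    # List the tensors stored for the first quantized layer
    prefix = next(iter(meta["layers"]))
    for key in sorted(k for k in f.keys() if k.startswith(prefix + ".")):
        print(key, tuple(f.get_tensor(key).shape))
```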
## AWQ
AWQ (Activation-aware Weight Quantization) is a 4-bit weight quantization scheme that keeps activations in 16-bit precision (W4A16):
```
Y = X @ W_quantized
```
Where:
- `X`: 16-bit activations (float16/bfloat16)
- `W_quantized`: 4-bit quantized weights with per-group scales and zero points
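Conceptually, every group of `group_size` consecutive input channels shares one scale and one zero point. A plain-PyTorch sketch of the per-group asymmetric dequantization, using unpacked 4-bit codes and hypothetical values (the real kernel operates on the packed int32 layout directly):
```python
import torch

out_features, in_features, group_size = 128, 256, 64
n_groups = in_features // group_size

q = torch.randint(0, 16, (out_features, in_features)).float()   # unpacked 4-bit codes
scales = torch.rand(n_groups, out_features) * 0.1                # per-group scales (wscales)
zeros = torch.randint(0, 16, (n_groups, out_features)).float()   # per-group zero points (wzeros)

# Asymmetric dequantization: w = (q - zero) * scale, applied per group of input channels
w = torch.empty(out_features, in_features)
for g in range(n_groups):
    cols = slice(g * group_size, (g + 1) * group_size)
    w[:, cols] = (q[:, cols] - zeros[g].unsqueeze(1)) * scales[g].unsqueeze(1)

x = torch.randn(4, in_features)  # activations stay in high precision (W4A16)
y = x @ w.T
```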
### Key Features
1. **W4A16 Quantization**:
- Quantizes weights to 4-bit while keeping activations in 16-bit
- Uses per-group quantization with configurable group size (typically 64)
- Stores zero points for asymmetric quantization
2. **Activation-Aware**:
- Quantization is calibrated based on activation statistics
- Protects salient weights that are important for accuracy
3. **Hardware Efficient**:
- Optimized for GPU inference
- Significantly reduces memory footprint
- Increases throughput with specialized kernels
### Implementation
AWQ requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```
The implementation uses the `awq_gemv_w4a16_cuda` kernel for efficient W4A16 matrix multiplication.
### Checkpoint Format
AWQ checkpoints contain the standard weight tensor (packed 4-bit weights in int32) plus additional parameters per quantized layer:
```python
{
"layer_name.weight": tensor, # Packed 4-bit weights (out_features // 4, in_features // 2)
"layer_name.wscales": tensor, # Weight scales (in_features // group_size, out_features)
"layer_name.wzeros": tensor, # Zero points (in_features // group_size, out_features)
}
```
The quantization metadata specifies which layers use AWQ:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": {"format": "awq_int4"},
"model.layers.0.mlp.down_proj": {"format": "awq_int4"}
}
}
}
```

View File

@ -4,7 +4,6 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
@ -12,8 +11,9 @@ import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate
@ -33,7 +33,9 @@ class FeedForward(nn.Module):
dropout: float = 0.0,
inner_dim=None,
bias: bool = True,
dtype=None, device=None, operations=None
dtype=None, device=None, operations=None,
svdquant_format=False,
**kwargs,
):
super().__init__()
if inner_dim is None:
@ -41,7 +43,7 @@ class FeedForward(nn.Module):
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations, **kwargs))
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
@ -92,7 +94,9 @@ class Attention(nn.Module):
out_context_dim: int = None,
dtype=None,
device=None,
operations=None
operations=None,
svdquant_format=False,
**kwargs,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@ -109,21 +113,30 @@ class Attention(nn.Module):
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.svdquant_format = svdquant_format
# Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
if self.svdquant_format: # svdq merged qkv for better perf
self.to_qkv = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
if self.svdquant_format:
self.add_qkv_proj = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
else:
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout)
])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward(
@ -140,29 +153,64 @@ class Attention(nn.Module):
seq_txt = encoder_hidden_states.shape[1]
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
if self.svdquant_format:
img_qkv = self.to_qkv(hidden_states)
img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
# Reshape for multi-head attention to [B, L, H, D]
img_query = img_query.unflatten(-1, (self.heads, -1)) # [B, L, H, D]
img_key = img_key.unflatten(-1, (self.heads, -1))
img_value = img_value.unflatten(-1, (self.heads, -1))
else:
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
if self.svdquant_format:
txt_qkv = self.add_qkv_proj(encoder_hidden_states)
txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
# Reshape for multi-head attention to [B, L, H, D]
txt_query = txt_query.unflatten(-1, (self.heads, -1))
txt_key = txt_key.unflatten(-1, (self.heads, -1))
txt_value = txt_value.unflatten(-1, (self.heads, -1))
else:
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
if self.svdquant_format:
# Concatenate image and text streams for joint attention
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
# Apply rotary embeddings to concatenated tensors
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
# Flatten to [B, L, H*D] for attention
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(
joint_query, joint_key, joint_value, self.heads, attention_mask
)
else:
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -183,28 +231,38 @@ class QwenImageTransformerBlock(nn.Module):
eps: float = 1e-6,
dtype=None,
device=None,
operations=None
operations=None,
scale_shift: float = None,
svdquant_format=False,
**kwargs,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.svdquant_format = svdquant_format
# For svdquant, scale_shift should be 0 as the shift is fused into weights
if scale_shift is None:
scale_shift = 0.0 if self.svdquant_format else 1.0
self.scale_shift = scale_shift
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
self.attn = Attention(
query_dim=dim,
@ -216,11 +274,18 @@ class QwenImageTransformerBlock(nn.Module):
dtype=dtype,
device=device,
operations=operations,
svdquant_format=svdquant_format,
**kwargs,
)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
if self.svdquant_format:
if self.scale_shift != 0:
scale.add_(self.scale_shift)
return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@ -233,21 +298,42 @@ class QwenImageTransformerBlock(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
# Nunchaku's fused modulation layers produce mod_params of shape [B, dim*6] with the six
# parameters interleaved per channel; reorder to six contiguous [B, dim] chunks so the
# chunk()-based modulation below works unchanged
if self.svdquant_format:
img_mod_params = (
img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
)
txt_mod_params = (
txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if self.svdquant_format:
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
)
else:
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
@ -258,6 +344,8 @@ class QwenImageTransformerBlock(nn.Module):
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
@ -307,7 +395,15 @@ class QwenImageTransformer2DModel(nn.Module):
dtype=None,
device=None,
operations=None,
scale_shift: float = None,
svdquant_format=False,
**kwargs,
):
# For svdquant, scale_shift should be 0 as the shift is fused into weights
self.svdquant_format = svdquant_format
if scale_shift is None:
scale_shift = 0.0 if self.svdquant_format else 1.0
super().__init__()
self.dtype = dtype
self.patch_size = patch_size
@ -336,7 +432,10 @@ class QwenImageTransformer2DModel(nn.Module):
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations
operations=operations,
scale_shift=scale_shift,
svdquant_format=svdquant_format,
**kwargs
)
for _ in range(num_layers)
])
@ -384,10 +483,12 @@ class QwenImageTransformer2DModel(nn.Module):
control=None,
**kwargs
):
timestep = timesteps
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
@ -419,7 +520,10 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
if self.svdquant_format:
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
else:
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
@ -441,6 +545,9 @@ class QwenImageTransformer2DModel(nn.Module):
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:

View File

@ -623,6 +623,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
# Add SVDQuant linear support
if '{}transformer_blocks.0.attn.add_qkv_proj.weight'.format(key_prefix) in state_dict_keys:
# try import nunchaku:
try:
from nunchaku.models.linear import SVDQW4A4Linear
except ImportError:
raise ImportError(
"SVDQuant requires the nunchaku library. "
"Please follow the instructions in https://nunchaku.tech/docs/nunchaku/installation/installation.html to install nunchaku"
)
dit_config["svdquant_format"] = True
if metadata is not None and 'config' in metadata:
if 'quantization_config' in metadata:
import json
metadata_quantization_config = json.loads(metadata['quantization_config'])
if 'weight' in metadata_quantization_config:
if metadata_quantization_config["weight"]["dtype"] == "fp4_e2m1_all":
if metadata_quantization_config["weight"]["group_size"] == 16:
dit_config['precision'] = "nvfp4"
elif metadata_quantization_config["weight"]["dtype"] == "int4":
dit_config['precision'] = "int4"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:

View File

@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
from comfy.quant_ops import QuantizedTensor
def run_every_op():
if torch.compiler.is_compiling():
@ -582,11 +583,17 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
if device is None and self.bias is not None:
device = self.bias.device
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
if device is None:
device = weight.device
manually_loaded_keys = [weight_key]
@ -600,27 +607,58 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
# Build layout_params - start with basic parameters
layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
'is_weight': True, # Mark this as a weight tensor
}
if layout_params['scale'] is not None:
# Add group_size and precision if present in qconfig
if 'group_size' in qconfig:
layout_params['group_size'] = qconfig['group_size']
if 'precision' in qconfig:
layout_params['precision'] = qconfig['precision']
# Handle weight_scale
weight_scale_key = f"{prefix}weight_scale"
weight_scale = state_dict.pop(weight_scale_key, None)
if weight_scale is not None:
layout_params['scale'] = weight_scale
manually_loaded_keys.append(weight_scale_key)
# custom_layer_params_keys are loaded into layout_params from state_dict
if 'custom_layer_params_keys' in qconfig:
for param_name in qconfig['custom_layer_params_keys']:
param_key = f"{prefix}{param_name}"
param_value = state_dict.pop(param_key, None)
if param_value is not None:
layout_params[param_name] = param_value.to(device=device).contiguous()
manually_loaded_keys.append(param_key)
else:
logging.warning(f"Missing custom parameter {param_name} for layer {layer_name}")
# parameters are loaded into module attributes from state_dict
for param_name in qconfig["parameters"]:
if param_name in layout_params:
continue # Already loaded via custom_layer_params_keys or weight_scale
param_key = f"{prefix}{param_name}"
param_value = state_dict.pop(param_key, None)
if param_value is not None:
# For standard parameters, store as module attributes
setattr(self, param_name, torch.nn.Parameter(param_value.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
# Create the quantized weight tensor
quantized_weight = QuantizedTensor(weight.to(device=device),
self.layout_type, layout_params)
self.weight_prefix = prefix
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
quantized_weight,
requires_grad=False
)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
self.weight.requires_grad = False
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@ -646,6 +684,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
@ -696,4 +735,4 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
return manual_cast

View File

@ -46,13 +46,28 @@ def register_generic_util(torch_op):
def _get_layout_from_args(args):
def _extract_layout(obj):
if isinstance(obj, QuantizedTensor):
return obj._layout_type
# For torch.nn.Parameter wrapping QuantizedTensor, check the data attribute
if isinstance(obj, torch.nn.Parameter):
if isinstance(obj.data, QuantizedTensor):
return obj.data._layout_type
if hasattr(obj.data, "_layout_type"):
return getattr(obj.data, "_layout_type", None)
if hasattr(obj, "_layout_type"):
return getattr(obj, "_layout_type", None)
return None
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
layout = _extract_layout(arg)
if layout is not None:
return layout
if isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
layout = _extract_layout(item)
if layout is not None:
return layout
return None
@ -438,6 +453,46 @@ QUANT_ALGOS = {
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
"svdquant_int4": {
"storage_t": torch.int8, # Packed 4-bit stored in int8
"parameters": {
"wscales",
"smooth_factor",
"smooth_factor_orig",
"proj_down",
"proj_up",
},
"custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up"],
"comfy_tensor_layout": "SVDQuantLayout",
"group_size": 64,
"precision": "int4",
},
"svdquant_nvfp4": {
"storage_t": torch.int8, # Packed 4-bit stored in int8
"parameters": {
"wscales",
"smooth_factor",
"smooth_factor_orig",
"proj_down",
"proj_up",
"wtscale",
"wcscales",
},
"custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up", "wtscale", "wcscales"],
"comfy_tensor_layout": "SVDQuantLayout",
"group_size": 16,
"precision": "nvfp4",
},
"awq_int4": {
"storage_t": torch.int32, # Packed 4-bit stored in int32
"parameters": {
"wscales",
"wzeros",
},
"custom_layer_params_keys": ["wscales", "wzeros"],
"comfy_tensor_layout": "AWQQuantLayout",
"group_size": 64,
},
}
LAYOUTS = {
@ -571,3 +626,439 @@ def fp8_func(func, args, kwargs):
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
# ==============================================================================
# SVDQuant Layout + Operation Handlers
# ==============================================================================
class SVDQuantLayout(QuantizedLayout):
"""
SVDQuant W4A4 quantization layout.
SVDQuant decomposes linear operations as:
X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
Where:
- proj_down, proj_up: Low-rank factorization of the original weights
- R: Residual weights (quantized to 4-bit)
- quantize(): 4-bit quantization with smoothing factors
Storage format:
For weights (is_weight=True):
- qdata: Packed quantized residual weights (out_features, in_features // 2), int8
- wscales: Weight quantization scales
- smooth_factor: Smoothing factors for inputs
- proj_down: Low-rank down projection
- proj_up: Low-rank up projection
- group_size: Quantization group size (64 for int4, 16 for nvfp4)
- precision: 'int4' or 'nvfp4'
- rank: SVD rank
- wtscale: Global weight scale (nvfp4 only)
- wcscales: Channel-wise weight scales (nvfp4 only)
- act_unsigned: Whether activations are unsigned (int4 only)
- orig_dtype: Original dtype before quantization
For activations (is_weight=False):
- qdata: Original activation tensor (not quantized yet)
- orig_dtype: Original dtype
- is_weight: False marker
"""
@classmethod
def quantize(cls, tensor, is_weight=True, **kwargs):
"""
For SVDQuant, we don't perform online quantization.
- Weights are pre-quantized offline and loaded from checkpoint
- Activations are stored as-is and quantized during forward pass
"""
orig_dtype = tensor.dtype
if is_weight:
# This shouldn't be called for weights as they're loaded pre-quantized
raise NotImplementedError(
"SVDQuant weights should be loaded pre-quantized from checkpoint, "
"not quantized on-the-fly"
)
else:
# For activations, just store the tensor as-is
# It will be quantized during the linear operation
layout_params = {
'orig_dtype': orig_dtype,
'is_weight': False
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, is_weight=True, orig_dtype=None, **kwargs):
"""
Dequantization for SVDQuant.
- Activations: return as-is (not actually quantized)
- Weights: full dequantization not supported (would need to reconstruct from SVD + residual)
"""
if not is_weight:
# Activations aren't actually quantized, just return them
return qdata.to(orig_dtype) if orig_dtype else qdata
else:
# Full weight dequantization is complex and not typically needed
# Would require: proj_down @ proj_up.T + dequantize(qweight)
raise NotImplementedError(
"Full dequantization of SVDQuant weights is not supported. "
"Use the quantized forward pass instead."
)
@classmethod
def get_plain_tensors(cls, qtensor):
"""Extract the raw tensors needed for SVDQuant computation."""
if qtensor._layout_params.get('is_weight', True):
# For weights, return all the necessary components
return {
'qweight': qtensor._qdata,
'wscales': qtensor._layout_params.get('wscales'),
'smooth_factor': qtensor._layout_params.get('smooth_factor'),
'proj_down': qtensor._layout_params.get('proj_down'),
'proj_up': qtensor._layout_params.get('proj_up'),
'group_size': qtensor._layout_params.get('group_size'),
'precision': qtensor._layout_params.get('precision', 'int4'),
'wtscale': qtensor._layout_params.get('wtscale'),
'wcscales': qtensor._layout_params.get('wcscales'),
'act_unsigned': qtensor._layout_params.get('act_unsigned', False),
}
else:
# For activations, just return the tensor
return qtensor._qdata
@register_layout_op(torch.ops.aten.addmm.default, "SVDQuantLayout")
@register_layout_op(torch.ops.aten.linear.default, "SVDQuantLayout")
def svdquant_linear(func, args, kwargs):
"""
SVDQuant linear operation handler.
Implements: X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
Handles both aten.linear and aten.addmm (which linear decomposes into).
"""
# Handle both linear and addmm calling conventions
if func == torch.ops.aten.addmm.default:
# addmm(bias, input, weight.t()) -> out
bias = args[0] if len(args) > 0 else None
input_tensor = args[1] if len(args) > 1 else None
weight = args[2] if len(args) > 2 else None
# Weight comes transposed in addmm, but SVDQuant stores it non-transposed
# So we need to transpose it back
need_transpose = True
else:
# linear(input, weight, bias) -> out
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
need_transpose = False
# Unwrap Parameter if necessary
if isinstance(weight, torch.nn.Parameter):
weight = weight.data
# Check if weight is SVDQuant quantized
if not isinstance(weight, QuantizedTensor) or weight._layout_type != "SVDQuantLayout":
# Fallback to standard linear
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
if func == torch.ops.aten.addmm.default:
return torch.addmm(bias, input_tensor, weight)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
# Extract weight parameters
weight_params = SVDQuantLayout.get_plain_tensors(weight)
qweight = weight_params['qweight']
wscales = weight_params['wscales']
smooth_factor = weight_params['smooth_factor']
proj_down = weight_params['proj_down']
proj_up = weight_params['proj_up']
group_size = weight_params['group_size']
precision = weight_params['precision']
wtscale = weight_params['wtscale']
wcscales = weight_params['wcscales']
act_unsigned = weight_params['act_unsigned']
# Get activation tensor (dequantize if it's a QuantizedTensor)
if isinstance(input_tensor, QuantizedTensor):
if input_tensor._layout_type == "SVDQuantLayout":
x = SVDQuantLayout.get_plain_tensors(input_tensor)
else:
x = input_tensor.dequantize()
else:
x = input_tensor
# Import nunchaku operations
try:
from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda
except ImportError:
raise ImportError(
"SVDQuant requires the nunchaku library. "
"Install it with: pip install nunchaku"
)
# Handle batch dimensions
original_shape = x.shape
if len(original_shape) == 2:
batch_size, channels = original_shape
seq_len = 1
x = x.view(batch_size, seq_len, channels)
elif len(original_shape) == 3:
batch_size, seq_len, channels = original_shape
else:
raise ValueError(f"SVDQuant linear expects 2D or 3D input, got {len(original_shape)}D")
# Reshape to 2D for computation
x_2d = x.reshape(batch_size * seq_len, channels)
original_batch_size = x_2d.shape[0] # Track original size before padding
# Step 1: Quantize activations and compute low-rank hidden states
# Output: quantized_x, ascales, lora_act_out
quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
x_2d,
lora_down=proj_down,
smooth=smooth_factor,
fp4=(precision == "nvfp4"),
pad_size=256
)
# Step 2: Compute quantized GEMM with low-rank residual
# Output shape: (N_padded, out_features) where N_padded may be larger due to padding
out_features = qweight.shape[0]
output = torch.empty(
quantized_x.shape[0],
out_features,
dtype=proj_up.dtype,
device=x.device
)
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=qweight,
out=output,
ascales=ascales,
wscales=wscales,
lora_act_in=lora_act_out,
lora_up=proj_up,
bias=bias,
fp4=(precision == "nvfp4"),
alpha=wtscale,
wcscales=wcscales,
act_unsigned=act_unsigned,
)
# Slice to remove padding and reshape back to original batch dimensions
output = output[:original_batch_size, :] # Remove padding
if len(original_shape) == 2:
output = output.view(batch_size, out_features)
else:
output = output.view(batch_size, seq_len, out_features)
return output
# ==============================================================================
# AWQ Layout + Operation Handlers
# ==============================================================================
class AWQQuantLayout(QuantizedLayout):
"""
AWQ W4A16 quantization layout.
AWQ (Activation-aware Weight Quantization) quantizes weights to 4-bit
while keeping activations in 16-bit precision (float16/bfloat16).
Storage format:
For weights (is_weight=True):
- qdata: Packed quantized weights (out_features // 4, in_features // 2), int32
- wscales: Weight quantization scales (in_features // group_size, out_features)
- wzeros: Weight zero points (in_features // group_size, out_features)
- group_size: Quantization group size (default 64)
- orig_dtype: Original dtype before quantization
For activations (is_weight=False):
- qdata: Original activation tensor (not quantized)
- orig_dtype: Original dtype
- is_weight: False marker
"""
@classmethod
def quantize(cls, tensor, is_weight=True, **kwargs):
"""
For AWQ, we don't perform online quantization.
- Weights are pre-quantized offline and loaded from checkpoint
- Activations remain in 16-bit precision
"""
orig_dtype = tensor.dtype
if is_weight:
# This shouldn't be called for weights as they're loaded pre-quantized
raise NotImplementedError(
"AWQ weights should be loaded pre-quantized from checkpoint, "
"not quantized on-the-fly"
)
else:
# For activations, just store the tensor as-is
layout_params = {
'orig_dtype': orig_dtype,
'is_weight': False
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, is_weight=True, orig_dtype=None, wscales=None, wzeros=None, group_size=64, **kwargs):
"""
Dequantization for AWQ.
- Activations: return as-is (not quantized)
- Weights: unpack and dequantize from 4-bit
"""
if not is_weight:
# Activations aren't quantized, just return them
return qdata.to(orig_dtype) if orig_dtype else qdata
else:
# Dequantize 4-bit weights
# qdata shape: (out_features // 4, in_features // 2), dtype int32
# Output shape should be: (out_features, in_features)
# This is a complex operation that requires unpacking 4-bit values
# For now, we'll raise an error and rely on the quantized forward pass
raise NotImplementedError(
"Full dequantization of AWQ weights is not yet supported. "
"Use the quantized forward pass instead."
)
@classmethod
def get_plain_tensors(cls, qtensor):
"""Extract the raw tensors needed for AWQ computation."""
if qtensor._layout_params.get('is_weight', True):
# For weights, return all the necessary components
return {
'qweight': qtensor._qdata,
'wscales': qtensor._layout_params.get('wscales'),
'wzeros': qtensor._layout_params.get('wzeros'),
'group_size': qtensor._layout_params.get('group_size', 64),
}
else:
# For activations, just return the tensor
return qtensor._qdata
@register_layout_op(torch.ops.aten.addmm.default, "AWQQuantLayout")
@register_layout_op(torch.ops.aten.linear.default, "AWQQuantLayout")
def awq_linear(func, args, kwargs):
"""
AWQ linear operation handler.
Implements W4A16 quantized linear using AWQ format.
Handles both aten.linear and aten.addmm (which linear decomposes into).
"""
# Handle both linear and addmm calling conventions
if func == torch.ops.aten.addmm.default:
# addmm(bias, input, weight.t()) -> out
bias = args[0] if len(args) > 0 else None
input_tensor = args[1] if len(args) > 1 else None
weight = args[2] if len(args) > 2 else None
else:
# linear(input, weight, bias) -> out
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
# Unwrap Parameter if necessary
if isinstance(weight, torch.nn.Parameter):
weight = weight.data
# Check if weight is AWQ quantized
if not isinstance(weight, QuantizedTensor) or weight._layout_type != "AWQQuantLayout":
# Fallback to standard linear
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
if func == torch.ops.aten.addmm.default:
return torch.addmm(bias, input_tensor, weight)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
# Extract weight parameters
weight_params = AWQQuantLayout.get_plain_tensors(weight)
qweight = weight_params['qweight']
wscales = weight_params['wscales']
wzeros = weight_params['wzeros']
group_size = weight_params['group_size']
# Get activation tensor (dequantize if it's a QuantizedTensor)
if isinstance(input_tensor, QuantizedTensor):
if input_tensor._layout_type == "AWQQuantLayout":
x = AWQQuantLayout.get_plain_tensors(input_tensor)
else:
x = input_tensor.dequantize()
else:
x = input_tensor
# Import nunchaku AWQ operation
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
raise ImportError(
"AWQ requires the nunchaku library. "
"Install it with: pip install nunchaku"
)
# Calculate output dimensions from packed weight shape
# qweight shape: (out_features // 4, in_features // 2)
out_features = qweight.shape[0] * 4
in_features = qweight.shape[1] * 2
# Handle batch dimensions
# Note: the nunchaku GEMV kernel expects 2D input of shape (batch_size, in_features)
original_shape = x.shape
if len(original_shape) == 2:
x_2d = x
else:
raise ValueError(f"AWQ linear expects 2D input, got {len(original_shape)}D")
# Ensure the input is contiguous (required by the CUDA kernel); only copies when necessary
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
output = awq_gemv_w4a16_cuda(
in_feats=x_2d,
kernel=qweight,
scaling_factors=wscales,
zeros=wzeros,
m=x_2d.shape[0],
n=out_features,
k=in_features,
group_size=group_size,
)
# Add bias if present
if bias is not None:
view_shape = [1] * (output.ndim - 1) + [-1]
output = output + bias.view(view_shape)
return output
LAYOUTS["SVDQuantLayout"] = SVDQuantLayout
LAYOUTS["AWQQuantLayout"] = AWQQuantLayout

377
comfy/svdquant_converter.py Normal file
View File

@ -0,0 +1,377 @@
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Note: Fused layer splitting is no longer used
@dataclass
class ConvertedState:
tensors: Dict[str, torch.Tensor]
quant_layers: Dict[str, str]
def _is_svd_prefix(keys: set[str], prefix: str) -> bool:
return (
f"{prefix}.qweight" in keys
and f"{prefix}.smooth_factor" in keys
and f"{prefix}.proj_down" in keys
and f"{prefix}.proj_up" in keys
)
def _is_awq_prefix(keys: set[str], prefix: str) -> bool:
return (
f"{prefix}.qweight" in keys
and f"{prefix}.wscales" in keys
and f"{prefix}.wzeros" in keys
and f"{prefix}.smooth_factor" not in keys # Distinguish from SVDQuant
)
def _detect_svd_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
prefixes = set()
keys = set(state_dict.keys())
for key in keys:
if not key.endswith(".qweight"):
continue
prefix = key[: -len(".qweight")]
if _is_svd_prefix(keys, prefix):
prefixes.add(prefix)
return sorted(prefixes)
def _detect_awq_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
prefixes = set()
keys = set(state_dict.keys())
for key in keys:
if not key.endswith(".qweight"):
continue
prefix = key[: -len(".qweight")]
if _is_awq_prefix(keys, prefix):
prefixes.add(prefix)
return sorted(prefixes)
def _detect_format(wscales: torch.Tensor) -> str:
if wscales.dtype == torch.float8_e4m3fn:
return "svdquant_nvfp4"
return "svdquant_int4"
class _SVDQuantConverter:
def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.src = dict(state_dict)
self.dst: Dict[str, torch.Tensor] = {}
self.quant_layers: Dict[str, str] = {}
def convert(self) -> ConvertedState:
prefixes = _detect_svd_prefixes(self.src)
for prefix in prefixes:
self._convert_single(prefix)
for key, tensor in self.src.items():
if key not in self.dst:
self.dst[key] = tensor
return ConvertedState(self.dst, self.quant_layers)
def _pop_tensor(self, key: str) -> torch.Tensor:
try:
return self.src.pop(key)
except KeyError as exc:
raise KeyError(f"Missing key '{key}' in SVDQuant checkpoint") from exc
def _pop_optional(self, key: str) -> torch.Tensor | None:
return self.src.pop(key, None)
def _convert_single(self, prefix: str) -> None:
# Ensure all tensors are contiguous to avoid CUDA alignment issues
self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
wscales = self._pop_tensor(f"{prefix}.wscales").contiguous()
self.dst[f"{prefix}.wscales"] = wscales
format_name = _detect_format(wscales)
self.dst[f"{prefix}.smooth_factor"] = self._pop_tensor(f"{prefix}.smooth_factor").contiguous()
self.dst[f"{prefix}.smooth_factor_orig"] = self._pop_tensor(
f"{prefix}.smooth_factor_orig"
).contiguous()
self.dst[f"{prefix}.proj_down"] = self._pop_tensor(f"{prefix}.proj_down").contiguous()
self.dst[f"{prefix}.proj_up"] = self._pop_tensor(f"{prefix}.proj_up").contiguous()
bias = self._pop_optional(f"{prefix}.bias")
if bias is not None:
self.dst[f"{prefix}.bias"] = bias.contiguous()
wtscale = self._pop_optional(f"{prefix}.wtscale")
if wtscale is not None:
self.dst[f"{prefix}.wtscale"] = wtscale.contiguous() if isinstance(wtscale, torch.Tensor) else wtscale
wcscales = self._pop_optional(f"{prefix}.wcscales")
if wcscales is not None:
self.dst[f"{prefix}.wcscales"] = wcscales.contiguous()
self.quant_layers[prefix] = format_name
class _AWQConverter:
def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.src = dict(state_dict)
self.dst: Dict[str, torch.Tensor] = {}
self.quant_layers: Dict[str, str] = {}
def convert(self) -> ConvertedState:
prefixes = _detect_awq_prefixes(self.src)
for prefix in prefixes:
self._convert_single(prefix)
for key, tensor in self.src.items():
if key not in self.dst:
self.dst[key] = tensor
return ConvertedState(self.dst, self.quant_layers)
def _pop_tensor(self, key: str) -> torch.Tensor:
try:
return self.src.pop(key)
except KeyError as exc:
raise KeyError(f"Missing key '{key}' in AWQ checkpoint") from exc
def _pop_optional(self, key: str) -> torch.Tensor | None:
return self.src.pop(key, None)
def _convert_single(self, prefix: str) -> None:
# Ensure all tensors are contiguous to avoid CUDA alignment issues
self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
self.dst[f"{prefix}.wscales"] = self._pop_tensor(f"{prefix}.wscales").contiguous()
self.dst[f"{prefix}.wzeros"] = self._pop_tensor(f"{prefix}.wzeros").contiguous()
bias = self._pop_optional(f"{prefix}.bias")
if bias is not None:
self.dst[f"{prefix}.bias"] = bias.contiguous()
self.quant_layers[prefix] = "awq_int4"
def convert_svdquant_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
return _SVDQuantConverter(state_dict).convert()
def convert_awq_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
return _AWQConverter(state_dict).convert()
def detect_quantization_formats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
"""
Detect quantization formats present in a state dict.
Parameters
----------
state_dict : Dict[str, torch.Tensor]
State dictionary to analyze
Returns
-------
Dict[str, List[str]]
Dictionary mapping format names to lists of layer prefixes
Example: {
"svdquant_int4": ["layer1.attn.qkv", "layer2.mlp.up"],
"svdquant_nvfp4": ["layer3.attn.qkv"],
"awq_int4": ["layer1.mlp.down", "layer4.attn.qkv"]
}
"""
result = {}
# Detect SVDQuant layers
svd_prefixes = _detect_svd_prefixes(state_dict)
if svd_prefixes:
# Determine if int4 or nvfp4 based on wscales dtype
for prefix in svd_prefixes:
wscales_key = f"{prefix}.wscales"
if wscales_key in state_dict:
format_name = _detect_format(state_dict[wscales_key])
if format_name not in result:
result[format_name] = []
result[format_name].append(prefix)
# Detect AWQ layers
awq_prefixes = _detect_awq_prefixes(state_dict)
if awq_prefixes:
result["awq_int4"] = awq_prefixes
return result
def convert_awq_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata())
converted = convert_awq_state_dict(tensors)
# Convert layer format dict to expected metadata format
# From: {"layer": "awq_int4"}
# To: {"layer": {"format": "awq_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers
def convert_svdquant_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata())
converted = convert_svdquant_state_dict(tensors)
# Convert layer format dict to expected metadata format
# From: {"layer": "svdquant_int4"}
# To: {"layer": {"format": "svdquant_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
metadata["model_class"] = "QwenImageTransformer2DModel"
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers
def convert_quantized_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
quant_format: str = "auto",
) -> Tuple[int, Dict[str, str]]:
"""
Auto-detect and convert quantized checkpoint to ComfyUI format.
Supports mixed-format models where some layers are SVDQuant and others are AWQ.
Each layer is independently detected and converted to the appropriate format.
Parameters
----------
input_path : str
Path to input checkpoint file
output_path : str
Path to output checkpoint file
format_version : str, optional
Quantization metadata format version (default: "1.0")
quant_format : str, optional
Quantization format: "auto", "svdquant", or "awq" (default: "auto")
Returns
-------
Tuple[int, Dict[str, str]]
Number of quantized layers and mapping of layer names to formats
"""
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata())
# Auto-detect format if needed
if quant_format == "auto":
svd_prefixes = _detect_svd_prefixes(tensors)
awq_prefixes = _detect_awq_prefixes(tensors)
if svd_prefixes and awq_prefixes:
# Mixed format - partition tensors by format and convert separately
# Build sets of all quantized prefixes
all_svd_prefixes = set(svd_prefixes)
all_awq_prefixes = set(awq_prefixes)
# Helper to check if a key belongs to a specific quantized layer
def belongs_to_prefix(key, prefix):
"""Check if key belongs to a specific layer prefix."""
return key == prefix or key.startswith(f"{prefix}.")
def is_svd_key(key):
"""Check if key belongs to any SVDQuant layer."""
return any(belongs_to_prefix(key, prefix) for prefix in all_svd_prefixes)
def is_awq_key(key):
"""Check if key belongs to any AWQ layer."""
return any(belongs_to_prefix(key, prefix) for prefix in all_awq_prefixes)
# Partition tensors by format
svd_tensors = {}
awq_tensors = {}
other_tensors = {}
for key, tensor in tensors.items():
if is_svd_key(key):
svd_tensors[key] = tensor
elif is_awq_key(key):
awq_tensors[key] = tensor
else:
other_tensors[key] = tensor
# Convert each format separately with only its relevant tensors
svd_converted = _SVDQuantConverter(svd_tensors).convert()
awq_converted = _AWQConverter(awq_tensors).convert()
# Merge results - each converter only has its own layer tensors
converted_tensors = {}
# Add SVDQuant converted tensors
converted_tensors.update(svd_converted.tensors)
# Add AWQ converted tensors
converted_tensors.update(awq_converted.tensors)
# Add non-quantized tensors
converted_tensors.update(other_tensors)
# Merge quantization layer metadata
quant_layers = {}
quant_layers.update(svd_converted.quant_layers)
quant_layers.update(awq_converted.quant_layers)
converted = ConvertedState(converted_tensors, quant_layers)
elif svd_prefixes:
converted = convert_svdquant_state_dict(tensors)
elif awq_prefixes:
converted = convert_awq_state_dict(tensors)
else:
raise ValueError("No quantized layers detected in checkpoint")
elif quant_format == "svdquant":
converted = convert_svdquant_state_dict(tensors)
elif quant_format == "awq":
converted = convert_awq_state_dict(tensors)
else:
raise ValueError(f"Unknown quantization format: {quant_format}")
# Convert layer format dict to expected metadata format
# From: {"layer": "awq_int4"}
# To: {"layer": {"format": "awq_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
metadata["model_class"] = "QwenImageTransformer2DModel"
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers

View File

@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Convert quantized checkpoints (SVDQuant, AWQ, or mixed) into the ComfyUI quantization format.
"""
import argparse
from pathlib import Path
from safetensors import safe_open
from comfy.svdquant_converter import (
convert_quantized_file,
convert_svdquant_file,
convert_awq_file,
detect_quantization_formats,
)
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Convert quantized .safetensors files (SVDQuant, AWQ, or mixed) "
"into the ComfyUI format with per-layer metadata for MixedPrecisionOps."
)
parser.add_argument("input", type=Path, help="Path to the source quantized .safetensors file.")
parser.add_argument(
"-o",
"--output",
type=Path,
help="Destination path for the converted checkpoint. "
"Defaults to <input_name>_comfy.safetensors in the same directory.",
)
parser.add_argument(
"--format-version",
default="1.0",
help="Format version to store inside _quantization_metadata (default: 1.0).",
)
parser.add_argument(
"--format",
choices=["auto", "svdquant", "awq"],
default="auto",
help="Quantization format (default: auto-detect).",
)
parser.add_argument(
"--detect-only",
action="store_true",
help="Only detect and report quantization formats without converting.",
)
return parser
def main() -> None:
parser = _build_parser()
args = parser.parse_args()
input_path = args.input.expanduser().resolve()
# Detect formats if requested
if args.detect_only:
print(f"[Quantization Detector] Analyzing: {input_path}")
with safe_open(str(input_path), framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
formats = detect_quantization_formats(tensors)
if not formats:
print("[Quantization Detector] No quantized layers detected.")
return
print(f"[Quantization Detector] Detected formats:")
total_layers = 0
for format_name, layer_prefixes in sorted(formats.items()):
print(f"\n {format_name}: {len(layer_prefixes)} layers")
for prefix in sorted(layer_prefixes)[:5]: # Show first 5
print(f" - {prefix}")
if len(layer_prefixes) > 5:
print(f" ... and {len(layer_prefixes) - 5} more")
total_layers += len(layer_prefixes)
print(f"\n[Quantization Detector] Total: {total_layers} quantized layers")
print(f"[Quantization Detector] Use without --detect-only to convert.")
return
# Convert checkpoint
if args.output is None:
output_path = input_path.with_name(f"{input_path.stem}_comfy.safetensors")
else:
output_path = args.output.expanduser().resolve()
layer_count, quant_layers = convert_quantized_file(
str(input_path),
str(output_path),
format_version=args.format_version,
quant_format=args.format,
)
# Group layers by format for display
format_groups = {}
for layer_name, fmt in quant_layers.items():
if fmt not in format_groups:
format_groups[fmt] = []
format_groups[fmt].append(layer_name)
print(f"[Quantization Converter] Converted {layer_count} layers.")
print(f"[Quantization Converter] Output saved to: {output_path}")
print(f"\n[Quantization Converter] Quantized layers by format:")
for fmt, layers in sorted(format_groups.items()):
print(f"\n {fmt}: {len(layers)} layers")
for layer_name in sorted(layers)[:5]: # Show first 5
print(f" - {layer_name}")
if len(layers) > 5:
print(f" ... and {len(layers) - 5} more")
if __name__ == "__main__":
main()
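For reference, the same conversion can also be driven from Python without the CLI. The sketch below is illustrative only: it uses the converter API exercised above (`convert_quantized_file`, `detect_quantization_formats`) with placeholder checkpoint paths, not files from this repository.

```
# Minimal sketch: programmatic use of the converter API shown above.
# The checkpoint paths are placeholders.
from safetensors import safe_open

from comfy.svdquant_converter import convert_quantized_file, detect_quantization_formats

src = "qwen_image_svdquant.safetensors"        # hypothetical input checkpoint
dst = "qwen_image_svdquant_comfy.safetensors"  # hypothetical output path

# Inspect which quantization formats are present before converting
with safe_open(src, framework="pt", device="cpu") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}
formats = detect_quantization_formats(tensors)
print({fmt: len(prefixes) for fmt, prefixes in formats.items()})

# Convert to the ComfyUI format with per-layer _quantization_metadata
layer_count, quant_layers = convert_quantized_file(
    src,
    dst,
    format_version="1.0",
    quant_format="auto",
)
print(f"Converted {layer_count} quantized layers")
```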

View File

@ -1,7 +1,10 @@
import os
import sys
import unittest
from pathlib import Path
import torch
from safetensors.torch import load_file
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@ -13,7 +16,9 @@ from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout
from comfy.ops import mixed_precision_ops
from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict
class TestQuantizedTensor(unittest.TestCase):
@ -156,6 +161,199 @@ class TestTensorCoreFP8Layout(unittest.TestCase):
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
class TestAWQQuantLayout(unittest.TestCase):
"""Test the AWQQuantLayout implementation"""
def test_awq_layout_creation(self):
"""Test creating an AWQ quantized tensor"""
# AWQ uses pre-quantized weights loaded from checkpoints
# Create dummy AWQ quantized weights
out_features, in_features = 256, 128
group_size = 64
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, qweight.shape)
self.assertEqual(qt.dtype, torch.int32)
self.assertEqual(qt._layout_type, "AWQQuantLayout")
self.assertEqual(qt._layout_params['group_size'], group_size)
def test_awq_quantize_not_supported(self):
"""Test that online quantization raises NotImplementedError for AWQ"""
# AWQ doesn't support online quantization - weights must be pre-quantized
float_tensor = torch.randn(32, 64, dtype=torch.float32)
with self.assertRaises(NotImplementedError):
AWQQuantLayout.quantize(float_tensor, is_weight=True)
def test_awq_get_plain_tensors(self):
"""Test extracting plain tensors from AWQ quantized tensor"""
out_features, in_features = 256, 128
group_size = 64
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
plain_tensors = AWQQuantLayout.get_plain_tensors(qt)
# Verify we can extract all necessary components
self.assertIsInstance(plain_tensors, dict)
self.assertIn('qweight', plain_tensors)
self.assertIn('wscales', plain_tensors)
self.assertIn('wzeros', plain_tensors)
self.assertIn('group_size', plain_tensors)
self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros))
class TestSVDQuantLayout(unittest.TestCase):
"""Test the SVDQuantLayout implementation"""
def test_svdquant_layout_creation(self):
"""Test creating an SVDQuant quantized tensor"""
# SVDQuant uses pre-quantized weights loaded from checkpoints
out_features, in_features = 256, 128
rank = 32
group_size = 64
precision = "int4"
# Create dummy SVDQuant quantized weights (int8 range is -128 to 127; randint's high bound is exclusive)
qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'smooth_factor': smooth_factor,
'smooth_factor_orig': smooth_factor_orig,
'proj_down': proj_down,
'proj_up': proj_up,
'group_size': group_size,
'precision': precision,
'orig_dtype': torch.bfloat16,
'is_weight': True,
'act_unsigned': False,
'wtscale': None,
'wcscales': None,
}
qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, qweight.shape)
self.assertEqual(qt.dtype, torch.int8)
self.assertEqual(qt._layout_type, "SVDQuantLayout")
self.assertEqual(qt._layout_params['group_size'], group_size)
self.assertEqual(qt._layout_params['precision'], precision)
def test_svdquant_quantize_not_supported(self):
"""Test that online quantization raises NotImplementedError for SVDQuant"""
# SVDQuant doesn't support online quantization - weights must be pre-quantized
float_tensor = torch.randn(32, 64, dtype=torch.float32)
with self.assertRaises(NotImplementedError):
SVDQuantLayout.quantize(float_tensor, is_weight=True)
def test_svdquant_dequantize_not_supported(self):
"""Test that weight dequantization raises NotImplementedError for SVDQuant"""
# Full weight dequantization is not supported (complex operation)
out_features, in_features = 256, 128
rank = 32
group_size = 64
qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
with self.assertRaises(NotImplementedError):
SVDQuantLayout.dequantize(
qweight,
is_weight=True,
wscales=wscales,
smooth_factor=smooth_factor,
proj_down=proj_down,
proj_up=proj_up,
group_size=group_size,
precision="int4",
orig_dtype=torch.bfloat16
)
def test_svdquant_get_plain_tensors(self):
"""Test extracting plain tensors from SVDQuant quantized tensor"""
out_features, in_features = 256, 128
rank = 32
group_size = 64
qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'smooth_factor': smooth_factor,
'smooth_factor_orig': smooth_factor_orig,
'proj_down': proj_down,
'proj_up': proj_up,
'group_size': group_size,
'precision': "int4",
'orig_dtype': torch.bfloat16,
'is_weight': True,
'act_unsigned': False,
'wtscale': None,
'wcscales': None,
}
qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
plain_tensors = SVDQuantLayout.get_plain_tensors(qt)
# Verify we can extract all necessary components
self.assertIsInstance(plain_tensors, dict)
self.assertIn('qweight', plain_tensors)
self.assertIn('wscales', plain_tensors)
self.assertIn('smooth_factor', plain_tensors)
self.assertIn('proj_down', plain_tensors)
self.assertIn('proj_up', plain_tensors)
self.assertIn('group_size', plain_tensors)
self.assertIn('precision', plain_tensors)
self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor))
self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down))
self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
@ -186,5 +384,158 @@ class TestFallbackMechanism(unittest.TestCase):
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
class TestAWQConversion(unittest.TestCase):
"""Test AWQ checkpoint conversion"""
def test_awq_single_layer_conversion(self):
"""Test converting a single AWQ layer"""
in_features, out_features = 128, 256
group_size = 64
# Create AWQ checkpoint format
state_dict = {
"layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
"layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.bias": torch.randn(out_features, dtype=torch.bfloat16),
}
converted = convert_awq_state_dict(state_dict)
# Check that qweight was renamed to weight
self.assertIn("layer.weight", converted.tensors)
self.assertNotIn("layer.qweight", converted.tensors)
# Check other parameters preserved
self.assertIn("layer.wscales", converted.tensors)
self.assertIn("layer.wzeros", converted.tensors)
self.assertIn("layer.bias", converted.tensors)
# Check quantization metadata
self.assertIn("layer", converted.quant_layers)
self.assertEqual(converted.quant_layers["layer"], "awq_int4")
def test_awq_tensor_shapes(self):
"""Test that converted AWQ tensors have correct shapes"""
in_features, out_features = 3072, 18432
group_size = 64
state_dict = {
"layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
"layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
}
converted = convert_awq_state_dict(state_dict)
# Check qweight shape (packed 4-bit)
qweight = converted.tensors["layer.weight"]
self.assertEqual(qweight.shape, (out_features // 4, in_features // 2))
self.assertEqual(qweight.dtype, torch.int32)
# Check wscales shape
wscales = converted.tensors["layer.wscales"]
self.assertEqual(wscales.shape, (in_features // group_size, out_features))
self.assertEqual(wscales.dtype, torch.bfloat16)
# Check wzeros shape
wzeros = converted.tensors["layer.wzeros"]
self.assertEqual(wzeros.shape, (in_features // group_size, out_features))
self.assertEqual(wzeros.dtype, torch.bfloat16)
class TestAWQLinearOperation(unittest.TestCase):
"""Test AWQ linear operations with actual nunchaku kernels"""
@unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
def test_awq_linear_basic(self):
"""Test basic AWQ linear operation by calling kernel directly"""
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
self.skipTest("nunchaku package not available")
device = torch.device("cuda")
in_features, out_features = 128, 256
group_size = 64
batch_size = 4
# Create AWQ quantized weight tensors
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
bias = torch.randn(out_features, dtype=torch.bfloat16, device=device)
# Create layout params
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
# Check that weight is a QuantizedTensor
self.assertIsInstance(weight, QuantizedTensor)
self.assertEqual(weight._layout_type, "AWQQuantLayout")
# Create input
x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
# Call AWQ linear handler directly
from comfy.quant_ops import awq_linear
output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {})
# Check output shape and dtype
self.assertEqual(output.shape, (batch_size, out_features))
self.assertEqual(output.dtype, torch.bfloat16)
@unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
def test_awq_linear_2d_input(self):
"""Test AWQ linear with 2D input (batch, features) by calling kernel directly"""
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
self.skipTest("nunchaku package not available")
device = torch.device("cuda")
in_features, out_features = 128, 256
group_size = 64
batch_size = 4
# Create AWQ quantized weight tensors
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
# Create layout params
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
# Check that weight is a QuantizedTensor
self.assertIsInstance(weight, QuantizedTensor)
self.assertEqual(weight._layout_type, "AWQQuantLayout")
# Create 2D input
x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
# Call AWQ linear handler directly
from comfy.quant_ops import awq_linear
output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})
# Check output shape
self.assertEqual(output.shape, (batch_size, out_features))
self.assertEqual(output.dtype, torch.bfloat16)
if __name__ == "__main__":
unittest.main()
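As a closing note on what `wscales` and `wzeros` represent in the AWQ tests, below is a minimal, illustrative sketch of grouped 4-bit dequantization on already-unpacked values. It assumes the common `(q - zero) * scale` convention; it is not the packed int32 layout or the nunchaku CUDA kernel exercised above, and some kernels store pre-scaled zero points instead.

```
# Illustrative only: grouped 4-bit dequantization on unpacked values, assuming
# the (q - zero) * scale convention. The nunchaku kernels above operate on
# packed int32 weights and may use a different zero-point convention and layout.
import torch

def dequantize_grouped_int4(q, wscales, wzeros, group_size=64):
    # q:       (out_features, in_features) unpacked 4-bit integers (e.g. int8)
    # wscales: (in_features // group_size, out_features), as in the tests above
    # wzeros:  (in_features // group_size, out_features)
    out_features, in_features = q.shape
    # Expand each group's per-output-channel scale/zero across its input columns
    scales = wscales.t().repeat_interleave(group_size, dim=1)  # (out, in)
    zeros = wzeros.t().repeat_interleave(group_size, dim=1)    # (out, in)
    return (q.to(wscales.dtype) - zeros) * scales

# Example with the shapes used in TestAWQQuantLayout
q = torch.randint(0, 16, (256, 128), dtype=torch.int8)
wscales = torch.randn(128 // 64, 256, dtype=torch.bfloat16)
wzeros = torch.randn(128 // 64, 256, dtype=torch.bfloat16)
w = dequantize_grouped_int4(q, wscales, wzeros)  # (256, 128), bfloat16
```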