add svdquant int4 support, modify qwen model to support nunchaku style merged qkv

This commit is contained in:
Yu Li 2025-11-29 12:01:17 -06:00
parent a17cf1c387
commit c8794e1155
8 changed files with 1721 additions and 67 deletions


@ -124,6 +124,22 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
| svdquant_int4 | int8 (packed 4-bit) | - | - | - | - |
| svdquant_nvfp4 | int8 (packed 4-bit) | - | - | - | - |
| awq_int4 | int32 (packed 4-bit) | - | - | - | - |
For SVDQuant formats, additional parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **smooth_factor**: Smoothing factors for inputs (shape: in_features)
- **smooth_factor_orig**: Original smoothing factors (shape: in_features)
- **proj_down**: Low-rank down projection (shape: in_features, rank)
- **proj_up**: Low-rank up projection (shape: out_features, rank)
- **wtscale**: Global weight scale (nvfp4 only, scalar float)
- **wcscales**: Channel-wise weight scales (nvfp4 only, shape: out_features)
For AWQ format, the following parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **wzeros**: Weight zero points (shape: in_features // group_size, out_features)
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
@ -139,9 +155,9 @@ Example:
"_quantization_metadata": { "_quantization_metadata": {
"format_version": "1.0", "format_version": "1.0",
"layers": { "layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn", "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
"model.layers.0.mlp.down_proj": "float8_e4m3fn", "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
"model.layers.1.mlp.up_proj": "float8_e4m3fn" "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
} }
} }
} }
@ -165,4 +181,137 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights

The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
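A minimal sketch of the statistics-collection step, assuming absmax calibration against the float8_e4m3fn range; the helper name and the hook-based approach are illustrative, not ComfyUI's API:
```python
import torch

def calibrate_input_scales(model, calib_batches, fp8_max=448.0):
    """Derive a scalar input_scale per Linear layer from activation absmax (sketch)."""
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            amax = inputs[0].detach().abs().max().item()
            stats[name] = max(stats.get(name, 0.0), amax)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        for batch in calib_batches:  # representative prompts / latents
            model(batch)
    for h in hooks:
        h.remove()

    # Map the observed activation range onto the FP8 representable range (448 for e4m3fn).
    return {name: amax / fp8_max for name, amax in stats.items() if amax > 0}
```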
## SVDQuant
SVDQuant is an advanced 4-bit quantization scheme that decomposes linear operations using low-rank factorization combined with residual quantization:
```
X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
```
Where:
- `proj_down`, `proj_up`: Low-rank factorization matrices of the original weights
- `R`: Residual weights (quantized to 4-bit)
- `quantize()`: 4-bit quantization with smoothing factors
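Ignoring the 4-bit packing and the CUDA kernels, the decomposition can be checked with a plain PyTorch reference; shapes follow the parameter list above, and `fake_quantize` is a stand-in for the real quantize/dequantize round trip:
```python
import torch

def svdquant_reference(x, proj_down, proj_up, residual, fake_quantize):
    """Reference for X*W = X @ proj_down @ proj_up.T + q(X) @ q(R).T (sketch).

    x:         (batch, in_features)
    proj_down: (in_features, rank)
    proj_up:   (out_features, rank)
    residual:  (out_features, in_features), the unpacked residual weights R
    """
    low_rank = x @ proj_down @ proj_up.T                       # cheap full-precision path
    quantized = fake_quantize(x) @ fake_quantize(residual).T   # 4-bit path in the real kernel
    return low_rank + quantized
```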
### Key Features
1. **Asymmetric Quantization**: Unlike FP8, where weights and activations share the same quantization scheme with scales computed offline, SVDQuant:
- Quantizes weights offline with multiple parameters stored in the checkpoint
- Quantizes activations on-the-fly during forward pass using smoothing factors
2. **Two Precision Modes**:
- `svdquant_int4`: 4-bit integer quantization with group_size=64
- `svdquant_nvfp4`: 4-bit floating-point (NVIDIA FP4) with group_size=16, includes additional channel-wise scales
3. **Low-Rank Optimization**: Separates the easy-to-approximate low-rank component from the hard-to-quantize residual, improving accuracy.
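The offline split is done by the exporter, not by this commit; conceptually it is a truncated SVD of the (smoothed) weight, as in this sketch, which ignores the smoothing step:
```python
import torch

def split_low_rank(weight, rank):
    """Split W (out_features, in_features) into W ~= proj_up @ proj_down.T + residual."""
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    proj_up = U[:, :rank] * S[:rank]   # (out_features, rank)
    proj_down = Vh[:rank, :].T         # (in_features, rank)
    residual = weight.float() - proj_up @ proj_down.T  # quantized to 4-bit afterwards
    return proj_down, proj_up, residual
```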
### Implementation
SVDQuant requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```
The implementation uses two main operations:
- `svdq_quantize_w4a4_act_fuse_lora_cuda`: Quantizes activations and computes low-rank hidden states
- `svdq_gemm_w4a4_cuda`: Performs the quantized GEMM with low-rank residual addition
### Checkpoint Format
SVDQuant checkpoints contain the standard weight tensor (packed 4-bit residuals in int8) plus additional parameters per quantized layer:
```python
{
    "layer_name.weight": tensor,              # Packed 4-bit residual weights (out_features, in_features // 2)
    "layer_name.wscales": tensor,             # Weight scales (in_features // group_size, out_features)
    "layer_name.smooth_factor": tensor,       # Smoothing factors (in_features,)
    "layer_name.smooth_factor_orig": tensor,  # Original smoothing factors (in_features,)
    "layer_name.proj_down": tensor,           # Low-rank down projection (in_features, rank)
    "layer_name.proj_up": tensor,             # Low-rank up projection (out_features, rank)
    # For nvfp4 only:
    "layer_name.wtscale": float,              # Global weight scale
    "layer_name.wcscales": tensor,            # Channel-wise scales (out_features,)
}
```
The quantization metadata specifies which layers use SVDQuant:
```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": {"format": "svdquant_int4"},
            "model.layers.0.mlp.down_proj": {"format": "svdquant_int4"}
        }
    }
}
```
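Reading the metadata back is a plain safetensors/JSON round trip; a sketch, with an illustrative file name:
```python
import json
from safetensors import safe_open

with safe_open("qwen_image_svdq_int4.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

quant_meta = json.loads(metadata["_quantization_metadata"])
for layer, info in quant_meta["layers"].items():
    print(layer, info["format"])  # e.g. transformer_blocks.0.attn.to_qkv svdquant_int4
```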
## AWQ
AWQ (Activation-aware Weight Quantization) is a 4-bit weight quantization scheme that keeps activations in 16-bit precision (W4A16):
```
Y = X @ W_quantized
```
Where:
- `X`: 16-bit activations (float16/bfloat16)
- `W_quantized`: 4-bit quantized weights with per-group scales and zero points
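For intuition only, a 4-bit code `q` in a group maps back to roughly `(q - zero) * scale`; the sketch below works on unpacked codes and assumes raw integer zero points (the packed int32 layout and any pre-scaling of `wzeros` used by the actual kernels are not modeled):
```python
import torch

def awq_dequantize_reference(q, wscales, wzeros, group_size=64):
    """q: (out_features, in_features) unpacked 4-bit codes in [0, 15].
    wscales, wzeros: (in_features // group_size, out_features)."""
    # Expand per-group scales/zeros to one value per input column, matching q's layout.
    scales = wscales.T.repeat_interleave(group_size, dim=1)  # (out_features, in_features)
    zeros = wzeros.T.repeat_interleave(group_size, dim=1)
    return (q.float() - zeros.float()) * scales.float()
```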
### Key Features
1. **W4A16 Quantization**:
- Quantizes weights to 4-bit while keeping activations in 16-bit
- Uses per-group quantization with configurable group size (typically 64)
- Stores zero points for asymmetric quantization
2. **Activation-Aware**:
- Quantization is calibrated based on activation statistics
- Protects salient weights that are important for accuracy
3. **Hardware Efficient**:
- Optimized for GPU inference
- Significantly reduces memory footprint
- Increases throughput with specialized kernels
### Implementation
AWQ requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```
The implementation uses the `awq_gemv_w4a16_cuda` kernel for efficient W4A16 matrix multiplication.
### Checkpoint Format
AWQ checkpoints contain the standard weight tensor (packed 4-bit weights in int32) plus additional parameters per quantized layer:
```python
{
    "layer_name.weight": tensor,   # Packed 4-bit weights (out_features // 4, in_features // 2)
    "layer_name.wscales": tensor,  # Weight scales (in_features // group_size, out_features)
    "layer_name.wzeros": tensor,   # Zero points (in_features // group_size, out_features)
}
```
The quantization metadata specifies which layers use AWQ:
```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": {"format": "awq_int4"},
            "model.layers.0.mlp.down_proj": {"format": "awq_int4"}
        }
    }
}
```


@ -4,7 +4,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, Tuple from typing import Optional, Tuple
from einops import repeat from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
@ -12,8 +11,9 @@ import comfy.ldm.common_dit
import comfy.patcher_extension import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module): class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None, **kwargs):
super().__init__() super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device) self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
self.approximate = approximate self.approximate = approximate
@ -33,7 +33,9 @@ class FeedForward(nn.Module):
dropout: float = 0.0, dropout: float = 0.0,
inner_dim=None, inner_dim=None,
bias: bool = True, bias: bool = True,
dtype=None, device=None, operations=None dtype=None, device=None, operations=None,
svdquant_format=False,
**kwargs,
): ):
super().__init__() super().__init__()
if inner_dim is None: if inner_dim is None:
@ -41,7 +43,7 @@ class FeedForward(nn.Module):
dim_out = dim_out if dim_out is not None else dim dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([]) self.net = nn.ModuleList([])
self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations)) self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations, **kwargs))
self.net.append(nn.Dropout(dropout)) self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device)) self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
@ -92,7 +94,9 @@ class Attention(nn.Module):
out_context_dim: int = None, out_context_dim: int = None,
dtype=None, dtype=None,
device=None, device=None,
operations=None operations=None,
svdquant_format=False,
**kwargs,
): ):
super().__init__() super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@ -109,21 +113,30 @@ class Attention(nn.Module):
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.svdquant_format = svdquant_format
# Image stream projections # Image stream projections
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) if self.svdquant_format: # svdq merged qkv for better perf
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) self.to_qkv = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) else:
self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Text stream projections # Text stream projections
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) if self.svdquant_format:
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) self.add_qkv_proj = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) else:
self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
# Output projections # Output projections
self.to_out = nn.ModuleList([ self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device), operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
nn.Dropout(dropout) nn.Dropout(dropout)
]) ])
self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device) self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
def forward( def forward(
@ -140,29 +153,64 @@ class Attention(nn.Module):
seq_txt = encoder_hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1]
# Project and reshape to BHND format (batch, heads, seq, dim) # Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() if self.svdquant_format:
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_qkv = self.to_qkv(hidden_states)
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
# Reshape for multi-head attention to [B, L, H, D]
img_query = img_query.unflatten(-1, (self.heads, -1)) # [B, L, H, D]
img_key = img_key.unflatten(-1, (self.heads, -1))
img_value = img_value.unflatten(-1, (self.heads, -1))
else:
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
if self.svdquant_format:
txt_qkv = self.add_qkv_proj(encoder_hidden_states)
txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
# Reshape for multi-head attention to [B, L, H, D]
txt_query = txt_query.unflatten(-1, (self.heads, -1))
txt_key = txt_key.unflatten(-1, (self.heads, -1))
txt_value = txt_value.unflatten(-1, (self.heads, -1))
else:
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query) img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key) img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query) txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key) txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=2) if self.svdquant_format:
joint_key = torch.cat([txt_key, img_key], dim=2) # Concatenate image and text streams for joint attention
joint_value = torch.cat([txt_value, img_value], dim=2) joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rope1(joint_query, image_rotary_emb) # Apply rotary embeddings to concatenated tensors
joint_key = apply_rope1(joint_key, image_rotary_emb) joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, # Flatten to [B, L, H*D] for attention
attention_mask, transformer_options=transformer_options, joint_query = joint_query.flatten(start_dim=2)
skip_reshape=True) joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(
joint_query, joint_key, joint_value, self.heads, attention_mask
)
else:
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :] txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :] img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -183,28 +231,38 @@ class QwenImageTransformerBlock(nn.Module):
eps: float = 1e-6, eps: float = 1e-6,
dtype=None, dtype=None,
device=None, device=None,
operations=None operations=None,
scale_shift: float = None,
svdquant_format=False,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim self.attention_head_dim = attention_head_dim
self.svdquant_format = svdquant_format
# For svdquant, scale_shift should be 0 as the shift is fused into weights
if scale_shift is None:
scale_shift = 0.0 if self.svdquant_format else 1.0
self.scale_shift = scale_shift
self.img_mod = nn.Sequential( self.img_mod = nn.Sequential(
nn.SiLU(), nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
) )
self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
self.txt_mod = nn.Sequential( self.txt_mod = nn.Sequential(
nn.SiLU(), nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
) )
self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
self.attn = Attention( self.attn = Attention(
query_dim=dim, query_dim=dim,
@ -216,11 +274,18 @@ class QwenImageTransformerBlock(nn.Module):
dtype=dtype, dtype=dtype,
device=device, device=device,
operations=operations, operations=operations,
svdquant_format=svdquant_format,
**kwargs,
) )
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) if self.svdquant_format:
if self.scale_shift != 0:
scale.add_(self.scale_shift)
return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward( def forward(
self, self,
@ -233,21 +298,42 @@ class QwenImageTransformerBlock(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb) img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb) txt_mod_params = self.txt_mod(temb)
# Nunchaku's mod_params layout is [B, dim*6] with different ordering
# Need to reshape from [B, dim*6] to correct layout
if self.svdquant_format:
img_mod_params = (
img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
)
txt_mod_params = (
txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1 del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1 del txt_mod1
img_attn_output, txt_attn_output = self.attn( if self.svdquant_format:
hidden_states=img_modulated, img_attn_output, txt_attn_output = self.attn(
encoder_hidden_states=txt_modulated, hidden_states=img_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states=txt_modulated,
image_rotary_emb=image_rotary_emb, encoder_hidden_states_mask=encoder_hidden_states_mask,
transformer_options=transformer_options, image_rotary_emb=image_rotary_emb,
) )
else:
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated del img_modulated
del txt_modulated del txt_modulated
@ -258,6 +344,8 @@ class QwenImageTransformerBlock(nn.Module):
del img_gate1 del img_gate1
del txt_gate1 del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
@ -307,7 +395,15 @@ class QwenImageTransformer2DModel(nn.Module):
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
scale_shift: float = None,
svdquant_format=False,
**kwargs,
): ):
# For svdquant, scale_shift should be 0 as the shift is fused into weights
self.svdquant_format = svdquant_format
if scale_shift is None:
scale_shift = 0.0 if self.svdquant_format else 1.0
super().__init__() super().__init__()
self.dtype = dtype self.dtype = dtype
self.patch_size = patch_size self.patch_size = patch_size
@ -336,7 +432,10 @@ class QwenImageTransformer2DModel(nn.Module):
attention_head_dim=attention_head_dim, attention_head_dim=attention_head_dim,
dtype=dtype, dtype=dtype,
device=device, device=device,
operations=operations operations=operations,
scale_shift=scale_shift,
svdquant_format=svdquant_format,
**kwargs
) )
for _ in range(num_layers) for _ in range(num_layers)
]) ])
@ -384,10 +483,12 @@ class QwenImageTransformer2DModel(nn.Module):
control=None, control=None,
**kwargs **kwargs
): ):
timestep = timesteps timestep = timesteps
encoder_hidden_states = context encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask encoder_hidden_states_mask = attention_mask
hidden_states, img_ids, orig_shape = self.process_img(x) hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1] num_embeds = hidden_states.shape[1]
@ -419,7 +520,10 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() if self.svdquant_format:
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
else:
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
@ -441,6 +545,9 @@ class QwenImageTransformer2DModel(nn.Module):
transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:


@ -623,6 +623,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "qwen_image" dit_config["image_model"] = "qwen_image"
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
# Add SVDQuant linear support
if '{}transformer_blocks.0.attn.add_qkv_proj.weight'.format(key_prefix) in state_dict_keys:
# try import nunchaku:
try:
from nunchaku.models.linear import SVDQW4A4Linear
except ImportError:
raise ImportError(
"SVDQuant requires the nunchaku library. "
"Please follow the instructions in https://nunchaku.tech/docs/nunchaku/installation/installation.html to install nunchaku"
)
dit_config["svdquant_format"] = True
if metadata is not None and 'config' in metadata:
if 'quantization_config' in metadata:
import json
metadata_quantization_config = json.loads(metadata['quantization_config'])
if 'weight' in metadata_quantization_config:
if metadata_quantization_config["weight"]["dtype"] == "fp4_e2m1_all":
if metadata_quantization_config["weight"]["group_size"] == 16:
dit_config['precision'] = "nvfp4"
elif metadata_quantization_config["weight"]["dtype"] == "int4":
dit_config['precision'] = "int4"
return dit_config return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:


@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
import comfy.float import comfy.float
import comfy.rmsnorm import comfy.rmsnorm
import contextlib import contextlib
from comfy.quant_ops import QuantizedTensor
def run_every_op(): def run_every_op():
if torch.compiler.is_compiling(): if torch.compiler.is_compiling():
@ -582,11 +583,17 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"] device = self.factory_kwargs["device"]
if device is None and self.bias is not None:
device = self.bias.device
layer_name = prefix.rstrip('.') layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight" weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None) weight = state_dict.pop(weight_key, None)
if weight is None: if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}") raise ValueError(f"Missing weight for layer {layer_name}")
if device is None:
device = weight.device
manually_loaded_keys = [weight_key] manually_loaded_keys = [weight_key]
@ -600,27 +607,58 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
qconfig = QUANT_ALGOS[quant_format] qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"] self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale" # Build layout_params - start with basic parameters
layout_params = { layout_params = {
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype, 'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None), 'is_weight': True, # Mark this as a weight tensor
} }
if layout_params['scale'] is not None:
# Add group_size and precision if present in qconfig
if 'group_size' in qconfig:
layout_params['group_size'] = qconfig['group_size']
if 'precision' in qconfig:
layout_params['precision'] = qconfig['precision']
# Handle weight_scale
weight_scale_key = f"{prefix}weight_scale"
weight_scale = state_dict.pop(weight_scale_key, None)
if weight_scale is not None:
layout_params['scale'] = weight_scale
manually_loaded_keys.append(weight_scale_key) manually_loaded_keys.append(weight_scale_key)
# custom_layer_params_keys are loaded into layout_params from state_dict
if 'custom_layer_params_keys' in qconfig:
for param_name in qconfig['custom_layer_params_keys']:
param_key = f"{prefix}{param_name}"
param_value = state_dict.pop(param_key, None)
if param_value is not None:
layout_params[param_name] = param_value.to(device=device).contiguous()
manually_loaded_keys.append(param_key)
else:
logging.warning(f"Missing custom parameter {param_name} for layer {layer_name}")
# parameters are loaded into module attributes from state_dict
for param_name in qconfig["parameters"]:
if param_name in layout_params:
continue # Already loaded via custom_layer_params_keys or weight_scale
param_key = f"{prefix}{param_name}"
param_value = state_dict.pop(param_key, None)
if param_value is not None:
# For standard parameters, store as module attributes
setattr(self, param_name, torch.nn.Parameter(param_value.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
# Create the quantized weight tensor
quantized_weight = QuantizedTensor(weight.to(device=device),
self.layout_type, layout_params)
self.weight_prefix = prefix
self.weight = torch.nn.Parameter( self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), quantized_weight,
requires_grad=False requires_grad=False
) )
self.weight.requires_grad = False
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@ -646,6 +684,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
getattr(self, 'input_scale', None) is not None and getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)): not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias) return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs): def convert_weight(self, weight, inplace=False, **kwargs):
@ -696,4 +735,4 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if compute_dtype is None or weight_dtype == compute_dtype: if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init return disable_weight_init
return manual_cast return manual_cast


@ -46,13 +46,28 @@ def register_generic_util(torch_op):
def _get_layout_from_args(args): def _get_layout_from_args(args):
def _extract_layout(obj):
if isinstance(obj, QuantizedTensor):
return obj._layout_type
# For torch.nn.Parameter wrapping QuantizedTensor, check the data attribute
if isinstance(obj, torch.nn.Parameter):
if isinstance(obj.data, QuantizedTensor):
return obj.data._layout_type
if hasattr(obj.data, "_layout_type"):
return getattr(obj.data, "_layout_type", None)
if hasattr(obj, "_layout_type"):
return getattr(obj, "_layout_type", None)
return None
for arg in args: for arg in args:
if isinstance(arg, QuantizedTensor): layout = _extract_layout(arg)
return arg._layout_type if layout is not None:
elif isinstance(arg, (list, tuple)): return layout
if isinstance(arg, (list, tuple)):
for item in arg: for item in arg:
if isinstance(item, QuantizedTensor): layout = _extract_layout(item)
return item._layout_type if layout is not None:
return layout
return None return None
@ -438,6 +453,46 @@ QUANT_ALGOS = {
"parameters": {"weight_scale", "input_scale"}, "parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout", "comfy_tensor_layout": "TensorCoreFP8Layout",
}, },
"svdquant_int4": {
"storage_t": torch.int8, # Packed 4-bit stored in int8
"parameters": {
"wscales",
"smooth_factor",
"smooth_factor_orig",
"proj_down",
"proj_up",
},
"custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up"],
"comfy_tensor_layout": "SVDQuantLayout",
"group_size": 64,
"precision": "int4",
},
"svdquant_nvfp4": {
"storage_t": torch.int8, # Packed 4-bit stored in int8
"parameters": {
"wscales",
"smooth_factor",
"smooth_factor_orig",
"proj_down",
"proj_up",
"wtscale",
"wcscales",
},
"custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up", "wtscale", "wcscales"],
"comfy_tensor_layout": "SVDQuantLayout",
"group_size": 16,
"precision": "nvfp4",
},
"awq_int4": {
"storage_t": torch.int32, # Packed 4-bit stored in int32
"parameters": {
"wscales",
"wzeros",
},
"custom_layer_params_keys": ["wscales", "wzeros"],
"comfy_tensor_layout": "AWQQuantLayout",
"group_size": 64,
},
} }
LAYOUTS = { LAYOUTS = {
@ -571,3 +626,439 @@ def fp8_func(func, args, kwargs):
ar[0] = plain_input ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs) return func(*args, **kwargs)
# ==============================================================================
# SVDQuant Layout + Operation Handlers
# ==============================================================================
class SVDQuantLayout(QuantizedLayout):
"""
SVDQuant W4A4 quantization layout.
SVDQuant decomposes linear operations as:
X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
Where:
- proj_down, proj_up: Low-rank factorization of weights
- R: Residual weights (quantized to 4-bit)
- quantize(): 4-bit quantization with smoothing factors
Storage format:
For weights (is_weight=True):
- qdata: Packed quantized residual weights (out_features, in_features // 2), int8
- wscales: Weight quantization scales
- smooth_factor: Smoothing factors for inputs
- proj_down: Low-rank down projection
- proj_up: Low-rank up projection
- group_size: Quantization group size (64 for int4, 16 for nvfp4)
- precision: 'int4' or 'nvfp4'
- rank: SVD rank
- wtscale: Global weight scale (nvfp4 only)
- wcscales: Channel-wise weight scales (nvfp4 only)
- act_unsigned: Whether activations are unsigned (int4 only)
- orig_dtype: Original dtype before quantization
For activations (is_weight=False):
- qdata: Original activation tensor (not quantized yet)
- orig_dtype: Original dtype
- is_weight: False marker
"""
@classmethod
def quantize(cls, tensor, is_weight=True, **kwargs):
"""
For SVDQuant, we don't perform online quantization.
- Weights are pre-quantized offline and loaded from checkpoint
- Activations are stored as-is and quantized during forward pass
"""
orig_dtype = tensor.dtype
if is_weight:
# This shouldn't be called for weights as they're loaded pre-quantized
raise NotImplementedError(
"SVDQuant weights should be loaded pre-quantized from checkpoint, "
"not quantized on-the-fly"
)
else:
# For activations, just store the tensor as-is
# It will be quantized during the linear operation
layout_params = {
'orig_dtype': orig_dtype,
'is_weight': False
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, is_weight=True, orig_dtype=None, **kwargs):
"""
Dequantization for SVDQuant.
- Activations: return as-is (not actually quantized)
- Weights: full dequantization not supported (would need to reconstruct from SVD + residual)
"""
if not is_weight:
# Activations aren't actually quantized, just return them
return qdata.to(orig_dtype) if orig_dtype else qdata
else:
# Full weight dequantization is complex and not typically needed
# Would require: proj_up @ proj_down.T + dequantize(qweight)
raise NotImplementedError(
"Full dequantization of SVDQuant weights is not supported. "
"Use the quantized forward pass instead."
)
@classmethod
def get_plain_tensors(cls, qtensor):
"""Extract the raw tensors needed for SVDQuant computation."""
if qtensor._layout_params.get('is_weight', True):
# For weights, return all the necessary components
return {
'qweight': qtensor._qdata,
'wscales': qtensor._layout_params.get('wscales'),
'smooth_factor': qtensor._layout_params.get('smooth_factor'),
'proj_down': qtensor._layout_params.get('proj_down'),
'proj_up': qtensor._layout_params.get('proj_up'),
'group_size': qtensor._layout_params.get('group_size'),
'precision': qtensor._layout_params.get('precision', 'int4'),
'wtscale': qtensor._layout_params.get('wtscale'),
'wcscales': qtensor._layout_params.get('wcscales'),
'act_unsigned': qtensor._layout_params.get('act_unsigned', False),
}
else:
# For activations, just return the tensor
return qtensor._qdata
@register_layout_op(torch.ops.aten.addmm.default, "SVDQuantLayout")
@register_layout_op(torch.ops.aten.linear.default, "SVDQuantLayout")
def svdquant_linear(func, args, kwargs):
"""
SVDQuant linear operation handler.
Implements: X*W = X * proj_up * proj_down + quantize(X) * quantize(R)
Handles both aten.linear and aten.addmm (which linear decomposes into).
"""
# Handle both linear and addmm calling conventions
if func == torch.ops.aten.addmm.default:
# addmm(bias, input, weight.t()) -> out
bias = args[0] if len(args) > 0 else None
input_tensor = args[1] if len(args) > 1 else None
weight = args[2] if len(args) > 2 else None
# Weight comes transposed in addmm, but SVDQuant stores it non-transposed
# So we need to transpose it back
need_transpose = True
else:
# linear(input, weight, bias) -> out
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
need_transpose = False
# Unwrap Parameter if necessary
if isinstance(weight, torch.nn.Parameter):
weight = weight.data
# Check if weight is SVDQuant quantized
if not isinstance(weight, QuantizedTensor) or weight._layout_type != "SVDQuantLayout":
# Fallback to standard linear
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
if func == torch.ops.aten.addmm.default:
return torch.addmm(bias, input_tensor, weight)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
# Extract weight parameters
weight_params = SVDQuantLayout.get_plain_tensors(weight)
qweight = weight_params['qweight']
wscales = weight_params['wscales']
smooth_factor = weight_params['smooth_factor']
proj_down = weight_params['proj_down']
proj_up = weight_params['proj_up']
group_size = weight_params['group_size']
precision = weight_params['precision']
wtscale = weight_params['wtscale']
wcscales = weight_params['wcscales']
act_unsigned = weight_params['act_unsigned']
# Get activation tensor (dequantize if it's a QuantizedTensor)
if isinstance(input_tensor, QuantizedTensor):
if input_tensor._layout_type == "SVDQuantLayout":
x = SVDQuantLayout.get_plain_tensors(input_tensor)
else:
x = input_tensor.dequantize()
else:
x = input_tensor
# Import nunchaku operations
try:
from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda
except ImportError:
raise ImportError(
"SVDQuant requires the nunchaku library. "
"Install it with: pip install nunchaku"
)
# Handle batch dimensions
original_shape = x.shape
if len(original_shape) == 2:
batch_size, channels = original_shape
seq_len = 1
x = x.view(batch_size, seq_len, channels)
elif len(original_shape) == 3:
batch_size, seq_len, channels = original_shape
else:
raise ValueError(f"SVDQuant linear expects 2D or 3D input, got {len(original_shape)}D")
# Reshape to 2D for computation
x_2d = x.reshape(batch_size * seq_len, channels)
original_batch_size = x_2d.shape[0] # Track original size before padding
# Step 1: Quantize activations and compute low-rank hidden states
# Output: quantized_x, ascales, lora_act_out
quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
x_2d,
lora_down=proj_down,
smooth=smooth_factor,
fp4=(precision == "nvfp4"),
pad_size=256
)
# Step 2: Compute quantized GEMM with low-rank residual
# Output shape: (N_padded, out_features) where N_padded may be larger due to padding
out_features = qweight.shape[0]
output = torch.empty(
quantized_x.shape[0],
out_features,
dtype=proj_up.dtype,
device=x.device
)
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=qweight,
out=output,
ascales=ascales,
wscales=wscales,
lora_act_in=lora_act_out,
lora_up=proj_up,
bias=bias,
fp4=(precision == "nvfp4"),
alpha=wtscale,
wcscales=wcscales,
act_unsigned=act_unsigned,
)
# Slice to remove padding and reshape back to original batch dimensions
output = output[:original_batch_size, :] # Remove padding
if len(original_shape) == 2:
output = output.view(batch_size, out_features)
else:
output = output.view(batch_size, seq_len, out_features)
return output
# ==============================================================================
# AWQ Layout + Operation Handlers
# ==============================================================================
class AWQQuantLayout(QuantizedLayout):
"""
AWQ W4A16 quantization layout.
AWQ (Activation-aware Weight Quantization) quantizes weights to 4-bit
while keeping activations in 16-bit precision (float16/bfloat16).
Storage format:
For weights (is_weight=True):
- qdata: Packed quantized weights (out_features // 4, in_features // 2), int32
- wscales: Weight quantization scales (in_features // group_size, out_features)
- wzeros: Weight zero points (in_features // group_size, out_features)
- group_size: Quantization group size (default 64)
- orig_dtype: Original dtype before quantization
For activations (is_weight=False):
- qdata: Original activation tensor (not quantized)
- orig_dtype: Original dtype
- is_weight: False marker
"""
@classmethod
def quantize(cls, tensor, is_weight=True, **kwargs):
"""
For AWQ, we don't perform online quantization.
- Weights are pre-quantized offline and loaded from checkpoint
- Activations remain in 16-bit precision
"""
orig_dtype = tensor.dtype
if is_weight:
# This shouldn't be called for weights as they're loaded pre-quantized
raise NotImplementedError(
"AWQ weights should be loaded pre-quantized from checkpoint, "
"not quantized on-the-fly"
)
else:
# For activations, just store the tensor as-is
layout_params = {
'orig_dtype': orig_dtype,
'is_weight': False
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, is_weight=True, orig_dtype=None, wscales=None, wzeros=None, group_size=64, **kwargs):
"""
Dequantization for AWQ.
- Activations: return as-is (not quantized)
- Weights: unpack and dequantize from 4-bit
"""
if not is_weight:
# Activations aren't quantized, just return them
return qdata.to(orig_dtype) if orig_dtype else qdata
else:
# Dequantize 4-bit weights
# qdata shape: (out_features // 4, in_features // 2), dtype int32
# Output shape should be: (out_features, in_features)
# This is a complex operation that requires unpacking 4-bit values
# For now, we'll raise an error and rely on the quantized forward pass
raise NotImplementedError(
"Full dequantization of AWQ weights is not yet supported. "
"Use the quantized forward pass instead."
)
@classmethod
def get_plain_tensors(cls, qtensor):
"""Extract the raw tensors needed for AWQ computation."""
if qtensor._layout_params.get('is_weight', True):
# For weights, return all the necessary components
return {
'qweight': qtensor._qdata,
'wscales': qtensor._layout_params.get('wscales'),
'wzeros': qtensor._layout_params.get('wzeros'),
'group_size': qtensor._layout_params.get('group_size', 64),
}
else:
# For activations, just return the tensor
return qtensor._qdata
@register_layout_op(torch.ops.aten.addmm.default, "AWQQuantLayout")
@register_layout_op(torch.ops.aten.linear.default, "AWQQuantLayout")
def awq_linear(func, args, kwargs):
"""
AWQ linear operation handler.
Implements W4A16 quantized linear using AWQ format.
Handles both aten.linear and aten.addmm (which linear decomposes into).
"""
# Handle both linear and addmm calling conventions
if func == torch.ops.aten.addmm.default:
# addmm(bias, input, weight.t()) -> out
bias = args[0] if len(args) > 0 else None
input_tensor = args[1] if len(args) > 1 else None
weight = args[2] if len(args) > 2 else None
else:
# linear(input, weight, bias) -> out
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
# Unwrap Parameter if necessary
if isinstance(weight, torch.nn.Parameter):
weight = weight.data
# Check if weight is AWQ quantized
if not isinstance(weight, QuantizedTensor) or weight._layout_type != "AWQQuantLayout":
# Fallback to standard linear
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
if func == torch.ops.aten.addmm.default:
return torch.addmm(bias, input_tensor, weight)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
# Extract weight parameters
weight_params = AWQQuantLayout.get_plain_tensors(weight)
qweight = weight_params['qweight']
wscales = weight_params['wscales']
wzeros = weight_params['wzeros']
group_size = weight_params['group_size']
# Get activation tensor (dequantize if it's a QuantizedTensor)
if isinstance(input_tensor, QuantizedTensor):
if input_tensor._layout_type == "AWQQuantLayout":
x = AWQQuantLayout.get_plain_tensors(input_tensor)
else:
x = input_tensor.dequantize()
else:
x = input_tensor
# Import nunchaku AWQ operation
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
raise ImportError(
"AWQ requires the nunchaku library. "
"Install it with: pip install nunchaku"
)
# Calculate output dimensions from packed weight shape
# qweight shape: (out_features // 4, in_features // 2)
out_features = qweight.shape[0] * 4
in_features = qweight.shape[1] * 2
# Handle batch dimensions - preserve original shape
# Important: nunchaku expects 2D input only, so we reshape 3D to 2D
original_shape = x.shape
if len(original_shape) == 2:
# (batch_size, in_features)
batch_size = original_shape[0]
x_2d = x
elif len(original_shape) == 3:
# (batch_size, seq_len, in_features) -> (batch_size * seq_len, in_features)
batch_size, seq_len, _ = original_shape
x_2d = x.reshape(batch_size * seq_len, in_features)
else:
raise ValueError(f"AWQ linear expects 2D or 3D input, got {len(original_shape)}D")
# Ensure input is contiguous (required by CUDA kernel)
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
output = awq_gemv_w4a16_cuda(
in_feats=x_2d,
kernel=qweight,
scaling_factors=wscales,
zeros=wzeros,
m=x_2d.shape[0],
n=out_features,
k=in_features,
group_size=group_size,
)
# Add bias if present
if bias is not None:
view_shape = [1] * (output.ndim - 1) + [-1]
output = output + bias.view(view_shape)
# Reshape back to original batch dimensions
if len(original_shape) == 3:
output = output.view(batch_size, seq_len, out_features)
return output
LAYOUTS["SVDQuantLayout"] = SVDQuantLayout
LAYOUTS["AWQQuantLayout"] = AWQQuantLayout

comfy/svdquant_converter.py (new file, 377 lines)

@ -0,0 +1,377 @@
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Note: Fused layer splitting is no longer used
@dataclass
class ConvertedState:
tensors: Dict[str, torch.Tensor]
quant_layers: Dict[str, str]
def _is_svd_prefix(keys: set[str], prefix: str) -> bool:
return (
f"{prefix}.qweight" in keys
and f"{prefix}.smooth_factor" in keys
and f"{prefix}.proj_down" in keys
and f"{prefix}.proj_up" in keys
)
def _is_awq_prefix(keys: set[str], prefix: str) -> bool:
return (
f"{prefix}.qweight" in keys
and f"{prefix}.wscales" in keys
and f"{prefix}.wzeros" in keys
and f"{prefix}.smooth_factor" not in keys # Distinguish from SVDQuant
)
def _detect_svd_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
prefixes = set()
keys = set(state_dict.keys())
for key in keys:
if not key.endswith(".qweight"):
continue
prefix = key[: -len(".qweight")]
if _is_svd_prefix(keys, prefix):
prefixes.add(prefix)
return sorted(prefixes)
def _detect_awq_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
prefixes = set()
keys = set(state_dict.keys())
for key in keys:
if not key.endswith(".qweight"):
continue
prefix = key[: -len(".qweight")]
if _is_awq_prefix(keys, prefix):
prefixes.add(prefix)
return sorted(prefixes)
def _detect_format(wscales: torch.Tensor) -> str:
if wscales.dtype == torch.float8_e4m3fn:
return "svdquant_nvfp4"
return "svdquant_int4"
class _SVDQuantConverter:
def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.src = dict(state_dict)
self.dst: Dict[str, torch.Tensor] = {}
self.quant_layers: Dict[str, str] = {}
def convert(self) -> ConvertedState:
prefixes = _detect_svd_prefixes(self.src)
for prefix in prefixes:
self._convert_single(prefix)
for key, tensor in self.src.items():
if key not in self.dst:
self.dst[key] = tensor
return ConvertedState(self.dst, self.quant_layers)
def _pop_tensor(self, key: str) -> torch.Tensor:
try:
return self.src.pop(key)
except KeyError as exc:
raise KeyError(f"Missing key '{key}' in SVDQuant checkpoint") from exc
def _pop_optional(self, key: str) -> torch.Tensor | None:
return self.src.pop(key, None)
def _convert_single(self, prefix: str) -> None:
# Ensure all tensors are contiguous to avoid CUDA alignment issues
self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
wscales = self._pop_tensor(f"{prefix}.wscales").contiguous()
self.dst[f"{prefix}.wscales"] = wscales
format_name = _detect_format(wscales)
self.dst[f"{prefix}.smooth_factor"] = self._pop_tensor(f"{prefix}.smooth_factor").contiguous()
self.dst[f"{prefix}.smooth_factor_orig"] = self._pop_tensor(
f"{prefix}.smooth_factor_orig"
).contiguous()
self.dst[f"{prefix}.proj_down"] = self._pop_tensor(f"{prefix}.proj_down").contiguous()
self.dst[f"{prefix}.proj_up"] = self._pop_tensor(f"{prefix}.proj_up").contiguous()
bias = self._pop_optional(f"{prefix}.bias")
if bias is not None:
self.dst[f"{prefix}.bias"] = bias.contiguous()
wtscale = self._pop_optional(f"{prefix}.wtscale")
if wtscale is not None:
self.dst[f"{prefix}.wtscale"] = wtscale.contiguous() if isinstance(wtscale, torch.Tensor) else wtscale
wcscales = self._pop_optional(f"{prefix}.wcscales")
if wcscales is not None:
self.dst[f"{prefix}.wcscales"] = wcscales.contiguous()
self.quant_layers[prefix] = format_name
class _AWQConverter:
def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
self.src = dict(state_dict)
self.dst: Dict[str, torch.Tensor] = {}
self.quant_layers: Dict[str, str] = {}
def convert(self) -> ConvertedState:
prefixes = _detect_awq_prefixes(self.src)
for prefix in prefixes:
self._convert_single(prefix)
for key, tensor in self.src.items():
if key not in self.dst:
self.dst[key] = tensor
return ConvertedState(self.dst, self.quant_layers)
def _pop_tensor(self, key: str) -> torch.Tensor:
try:
return self.src.pop(key)
except KeyError as exc:
raise KeyError(f"Missing key '{key}' in AWQ checkpoint") from exc
def _pop_optional(self, key: str) -> torch.Tensor | None:
return self.src.pop(key, None)
def _convert_single(self, prefix: str) -> None:
# Ensure all tensors are contiguous to avoid CUDA alignment issues
self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
self.dst[f"{prefix}.wscales"] = self._pop_tensor(f"{prefix}.wscales").contiguous()
self.dst[f"{prefix}.wzeros"] = self._pop_tensor(f"{prefix}.wzeros").contiguous()
bias = self._pop_optional(f"{prefix}.bias")
if bias is not None:
self.dst[f"{prefix}.bias"] = bias.contiguous()
self.quant_layers[prefix] = "awq_int4"
def convert_svdquant_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
return _SVDQuantConverter(state_dict).convert()
def convert_awq_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
return _AWQConverter(state_dict).convert()
def detect_quantization_formats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
"""
Detect quantization formats present in a state dict.
Parameters
----------
state_dict : Dict[str, torch.Tensor]
State dictionary to analyze
Returns
-------
Dict[str, List[str]]
Dictionary mapping format names to lists of layer prefixes
Example: {
"svdquant_int4": ["layer1.attn.qkv", "layer2.mlp.up"],
"svdquant_nvfp4": ["layer3.attn.qkv"],
"awq_int4": ["layer1.mlp.down", "layer4.attn.qkv"]
}
"""
result = {}
# Detect SVDQuant layers
svd_prefixes = _detect_svd_prefixes(state_dict)
if svd_prefixes:
# Determine if int4 or nvfp4 based on wscales dtype
for prefix in svd_prefixes:
wscales_key = f"{prefix}.wscales"
if wscales_key in state_dict:
format_name = _detect_format(state_dict[wscales_key])
if format_name not in result:
result[format_name] = []
result[format_name].append(prefix)
# Detect AWQ layers
awq_prefixes = _detect_awq_prefixes(state_dict)
if awq_prefixes:
result["awq_int4"] = awq_prefixes
return result
def convert_awq_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata())
converted = convert_awq_state_dict(tensors)
# Convert layer format dict to expected metadata format
# From: {"layer": "awq_int4"}
# To: {"layer": {"format": "awq_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers
def convert_svdquant_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata())
converted = convert_svdquant_state_dict(tensors)
# Convert layer format dict to expected metadata format
# From: {"layer": "svdquant_int4"}
# To: {"layer": {"format": "svdquant_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
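# Note: model_class is currently hard-coded for the Qwen-Image checkpoints this converter targets.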
metadata["model_class"] = "QwenImageTransformer2DModel"
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers
def convert_quantized_file(
input_path: str,
output_path: str,
format_version: str = "1.0",
quant_format: str = "auto",
) -> Tuple[int, Dict[str, str]]:
"""
Auto-detect and convert a quantized checkpoint to the ComfyUI format.
Supports mixed-format models where some layers are SVDQuant and others are AWQ.
Each layer is independently detected and converted to the appropriate format.
Parameters
----------
input_path : str
Path to input checkpoint file
output_path : str
Path to output checkpoint file
format_version : str, optional
Quantization metadata format version (default: "1.0")
quant_format : str, optional
Quantization format: "auto", "svdquant", or "awq" (default: "auto")
Returns
-------
Tuple[int, Dict[str, str]]
Number of quantized layers and mapping of layer names to formats
"""
with safe_open(input_path, framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
metadata = dict(f.metadata() or {})
# Auto-detect format if needed
if quant_format == "auto":
svd_prefixes = _detect_svd_prefixes(tensors)
awq_prefixes = _detect_awq_prefixes(tensors)
if svd_prefixes and awq_prefixes:
# Mixed format - partition tensors by format and convert separately
# Build sets of all quantized prefixes
all_svd_prefixes = set(svd_prefixes)
all_awq_prefixes = set(awq_prefixes)
# Helper to check if a key belongs to a specific quantized layer
def belongs_to_prefix(key, prefix):
"""Check if key belongs to a specific layer prefix."""
return key == prefix or key.startswith(f"{prefix}.")
def is_svd_key(key):
"""Check if key belongs to any SVDQuant layer."""
return any(belongs_to_prefix(key, prefix) for prefix in all_svd_prefixes)
def is_awq_key(key):
"""Check if key belongs to any AWQ layer."""
return any(belongs_to_prefix(key, prefix) for prefix in all_awq_prefixes)
# Partition tensors by format
svd_tensors = {}
awq_tensors = {}
other_tensors = {}
for key, tensor in tensors.items():
if is_svd_key(key):
svd_tensors[key] = tensor
elif is_awq_key(key):
awq_tensors[key] = tensor
else:
other_tensors[key] = tensor
# Convert each format separately with only its relevant tensors
svd_converted = _SVDQuantConverter(svd_tensors).convert()
awq_converted = _AWQConverter(awq_tensors).convert()
# Merge results - each converter only has its own layer tensors
converted_tensors = {}
# Add SVDQuant converted tensors
converted_tensors.update(svd_converted.tensors)
# Add AWQ converted tensors
converted_tensors.update(awq_converted.tensors)
# Add non-quantized tensors
converted_tensors.update(other_tensors)
# Merge quantization layer metadata
quant_layers = {}
quant_layers.update(svd_converted.quant_layers)
quant_layers.update(awq_converted.quant_layers)
converted = ConvertedState(converted_tensors, quant_layers)
elif svd_prefixes:
converted = convert_svdquant_state_dict(tensors)
elif awq_prefixes:
converted = convert_awq_state_dict(tensors)
else:
raise ValueError("No quantized layers detected in checkpoint")
elif quant_format == "svdquant":
converted = convert_svdquant_state_dict(tensors)
elif quant_format == "awq":
converted = convert_awq_state_dict(tensors)
else:
raise ValueError(f"Unknown quantization format: {quant_format}")
# Convert layer format dict to expected metadata format
# From: {"layer": "awq_int4"}
# To: {"layer": {"format": "awq_int4"}}
layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
metadata["_quantization_metadata"] = json.dumps(
{"format_version": format_version, "layers": layers_metadata}, sort_keys=True
)
metadata["model_class"] = "QwenImageTransformer2DModel"
save_file(converted.tensors, output_path, metadata=metadata)
return len(converted.quant_layers), converted.quant_layers
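# Programmatic usage sketch (file names are illustrative):
#
#     n_layers, layer_formats = convert_quantized_file(
#         "qwen_image_svdq_int4.safetensors",
#         "qwen_image_svdq_int4_comfy.safetensors",
#         quant_format="auto",
#     )
#     # layer_formats maps each quantized layer prefix to "svdquant_int4",
#     # "svdquant_nvfp4", or "awq_int4".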

View File

@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Convert quantized checkpoints (SVDQuant, AWQ, or mixed) into the ComfyUI quantization format.
"""
import argparse
from pathlib import Path
from safetensors import safe_open
from comfy.svdquant_converter import (
convert_quantized_file,
convert_svdquant_file,
convert_awq_file,
detect_quantization_formats,
)
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Convert quantized .safetensors files (SVDQuant, AWQ, or mixed) "
"into the ComfyUI format with per-layer metadata for MixedPrecisionOps."
)
parser.add_argument("input", type=Path, help="Path to the source quantized .safetensors file.")
parser.add_argument(
"-o",
"--output",
type=Path,
help="Destination path for the converted checkpoint. "
"Defaults to <input_name>_comfy.safetensors in the same directory.",
)
parser.add_argument(
"--format-version",
default="1.0",
help="Format version to store inside _quantization_metadata (default: 1.0).",
)
parser.add_argument(
"--format",
choices=["auto", "svdquant", "awq"],
default="auto",
help="Quantization format (default: auto-detect).",
)
parser.add_argument(
"--detect-only",
action="store_true",
help="Only detect and report quantization formats without converting.",
)
return parser
def main() -> None:
parser = _build_parser()
args = parser.parse_args()
input_path = args.input.expanduser().resolve()
# Detect formats if requested
if args.detect_only:
print(f"[Quantization Detector] Analyzing: {input_path}")
with safe_open(str(input_path), framework="pt", device="cpu") as f:
tensors = {key: f.get_tensor(key) for key in f.keys()}
formats = detect_quantization_formats(tensors)
if not formats:
print("[Quantization Detector] No quantized layers detected.")
return
print(f"[Quantization Detector] Detected formats:")
total_layers = 0
for format_name, layer_prefixes in sorted(formats.items()):
print(f"\n {format_name}: {len(layer_prefixes)} layers")
for prefix in sorted(layer_prefixes)[:5]: # Show first 5
print(f" - {prefix}")
if len(layer_prefixes) > 5:
print(f" ... and {len(layer_prefixes) - 5} more")
total_layers += len(layer_prefixes)
print(f"\n[Quantization Detector] Total: {total_layers} quantized layers")
print(f"[Quantization Detector] Use without --detect-only to convert.")
return
# Convert checkpoint
if args.output is None:
output_path = input_path.with_name(f"{input_path.stem}_comfy.safetensors")
else:
output_path = args.output.expanduser().resolve()
layer_count, quant_layers = convert_quantized_file(
str(input_path),
str(output_path),
format_version=args.format_version,
quant_format=args.format,
)
# Group layers by format for display
format_groups = {}
for layer_name, fmt in quant_layers.items():
if fmt not in format_groups:
format_groups[fmt] = []
format_groups[fmt].append(layer_name)
print(f"[Quantization Converter] Converted {layer_count} layers.")
print(f"[Quantization Converter] Output saved to: {output_path}")
print(f"\n[Quantization Converter] Quantized layers by format:")
for fmt, layers in sorted(format_groups.items()):
print(f"\n {fmt}: {len(layers)} layers")
for layer_name in sorted(layers)[:5]: # Show first 5
print(f" - {layer_name}")
if len(layers) > 5:
print(f" ... and {len(layers) - 5} more")
if __name__ == "__main__":
main()

View File

@ -1,7 +1,10 @@
import unittest
import torch
import sys
import os
import sys
import unittest
from pathlib import Path
import torch
from safetensors.torch import load_file
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@ -13,7 +16,9 @@ from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout
from comfy.ops import mixed_precision_ops
from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict
class TestQuantizedTensor(unittest.TestCase):
@ -156,6 +161,199 @@ class TestTensorCoreFP8Layout(unittest.TestCase):
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
class TestAWQQuantLayout(unittest.TestCase):
"""Test the AWQQuantLayout implementation"""
def test_awq_layout_creation(self):
"""Test creating an AWQ quantized tensor"""
# AWQ uses pre-quantized weights loaded from checkpoints
# Create dummy AWQ quantized weights
out_features, in_features = 256, 128
group_size = 64
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, qweight.shape)
self.assertEqual(qt.dtype, torch.int32)
self.assertEqual(qt._layout_type, "AWQQuantLayout")
self.assertEqual(qt._layout_params['group_size'], group_size)
def test_awq_quantize_not_supported(self):
"""Test that online quantization raises NotImplementedError for AWQ"""
# AWQ doesn't support online quantization - weights must be pre-quantized
float_tensor = torch.randn(32, 64, dtype=torch.float32)
with self.assertRaises(NotImplementedError):
AWQQuantLayout.quantize(float_tensor, is_weight=True)
def test_awq_get_plain_tensors(self):
"""Test extracting plain tensors from AWQ quantized tensor"""
out_features, in_features = 256, 128
group_size = 64
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
plain_tensors = AWQQuantLayout.get_plain_tensors(qt)
# Verify we can extract all necessary components
self.assertIsInstance(plain_tensors, dict)
self.assertIn('qweight', plain_tensors)
self.assertIn('wscales', plain_tensors)
self.assertIn('wzeros', plain_tensors)
self.assertIn('group_size', plain_tensors)
self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros))
class TestSVDQuantLayout(unittest.TestCase):
"""Test the SVDQuantLayout implementation"""
def test_svdquant_layout_creation(self):
"""Test creating an SVDQuant quantized tensor"""
# SVDQuant uses pre-quantized weights loaded from checkpoints
out_features, in_features = 256, 128
rank = 32
group_size = 64
precision = "int4"
# Create dummy SVDQuant quantized weights (int8 range is -128 to 127)
qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'smooth_factor': smooth_factor,
'smooth_factor_orig': smooth_factor_orig,
'proj_down': proj_down,
'proj_up': proj_up,
'group_size': group_size,
'precision': precision,
'orig_dtype': torch.bfloat16,
'is_weight': True,
'act_unsigned': False,
'wtscale': None,
'wcscales': None,
}
qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, qweight.shape)
self.assertEqual(qt.dtype, torch.int8)
self.assertEqual(qt._layout_type, "SVDQuantLayout")
self.assertEqual(qt._layout_params['group_size'], group_size)
self.assertEqual(qt._layout_params['precision'], precision)
def test_svdquant_quantize_not_supported(self):
"""Test that online quantization raises NotImplementedError for SVDQuant"""
# SVDQuant doesn't support online quantization - weights must be pre-quantized
float_tensor = torch.randn(32, 64, dtype=torch.float32)
with self.assertRaises(NotImplementedError):
SVDQuantLayout.quantize(float_tensor, is_weight=True)
def test_svdquant_dequantize_not_supported(self):
"""Test that weight dequantization raises NotImplementedError for SVDQuant"""
# Full weight dequantization is not supported (complex operation)
out_features, in_features = 256, 128
rank = 32
group_size = 64
qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
with self.assertRaises(NotImplementedError):
SVDQuantLayout.dequantize(
qweight,
is_weight=True,
wscales=wscales,
smooth_factor=smooth_factor,
proj_down=proj_down,
proj_up=proj_up,
group_size=group_size,
precision="int4",
orig_dtype=torch.bfloat16
)
def test_svdquant_get_plain_tensors(self):
"""Test extracting plain tensors from SVDQuant quantized tensor"""
out_features, in_features = 256, 128
rank = 32
group_size = 64
qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
layout_params = {
'wscales': wscales,
'smooth_factor': smooth_factor,
'smooth_factor_orig': smooth_factor_orig,
'proj_down': proj_down,
'proj_up': proj_up,
'group_size': group_size,
'precision': "int4",
'orig_dtype': torch.bfloat16,
'is_weight': True,
'act_unsigned': False,
'wtscale': None,
'wcscales': None,
}
qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
plain_tensors = SVDQuantLayout.get_plain_tensors(qt)
# Verify we can extract all necessary components
self.assertIsInstance(plain_tensors, dict)
self.assertIn('qweight', plain_tensors)
self.assertIn('wscales', plain_tensors)
self.assertIn('smooth_factor', plain_tensors)
self.assertIn('proj_down', plain_tensors)
self.assertIn('proj_up', plain_tensors)
self.assertIn('group_size', plain_tensors)
self.assertIn('precision', plain_tensors)
self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor))
self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down))
self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
@ -186,5 +384,158 @@ class TestFallbackMechanism(unittest.TestCase):
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
class TestAWQConversion(unittest.TestCase):
"""Test AWQ checkpoint conversion"""
def test_awq_single_layer_conversion(self):
"""Test converting a single AWQ layer"""
in_features, out_features = 128, 256
group_size = 64
# Create AWQ checkpoint format
state_dict = {
"layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
"layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.bias": torch.randn(out_features, dtype=torch.bfloat16),
}
converted = convert_awq_state_dict(state_dict)
# Check that qweight was renamed to weight
self.assertIn("layer.weight", converted.tensors)
self.assertNotIn("layer.qweight", converted.tensors)
# Check other parameters preserved
self.assertIn("layer.wscales", converted.tensors)
self.assertIn("layer.wzeros", converted.tensors)
self.assertIn("layer.bias", converted.tensors)
# Check quantization metadata
self.assertIn("layer", converted.quant_layers)
self.assertEqual(converted.quant_layers["layer"], "awq_int4")
def test_awq_tensor_shapes(self):
"""Test that converted AWQ tensors have correct shapes"""
in_features, out_features = 3072, 18432
group_size = 64
state_dict = {
"layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
"layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
"layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
}
converted = convert_awq_state_dict(state_dict)
# Check qweight shape (packed 4-bit)
qweight = converted.tensors["layer.weight"]
self.assertEqual(qweight.shape, (out_features // 4, in_features // 2))
self.assertEqual(qweight.dtype, torch.int32)
# Check wscales shape
wscales = converted.tensors["layer.wscales"]
self.assertEqual(wscales.shape, (in_features // group_size, out_features))
self.assertEqual(wscales.dtype, torch.bfloat16)
# Check wzeros shape
wzeros = converted.tensors["layer.wzeros"]
self.assertEqual(wzeros.shape, (in_features // group_size, out_features))
self.assertEqual(wzeros.dtype, torch.bfloat16)
class TestAWQLinearOperation(unittest.TestCase):
"""Test AWQ linear operations with actual nunchaku kernels"""
@unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
def test_awq_linear_basic(self):
"""Test basic AWQ linear operation by calling kernel directly"""
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
self.skipTest("nunchaku package not available")
device = torch.device("cuda")
in_features, out_features = 128, 256
group_size = 64
batch_size = 4
# Create AWQ quantized weight tensors
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
bias = torch.randn(out_features, dtype=torch.bfloat16, device=device)
# Create layout params
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
# Check that weight is a QuantizedTensor
self.assertIsInstance(weight, QuantizedTensor)
self.assertEqual(weight._layout_type, "AWQQuantLayout")
# Create input
x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
# Call AWQ linear handler directly
from comfy.quant_ops import awq_linear
output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {})
# Check output shape and dtype
self.assertEqual(output.shape, (batch_size, out_features))
self.assertEqual(output.dtype, torch.bfloat16)
@unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
def test_awq_linear_2d_input(self):
"""Test AWQ linear with 2D input (batch, features) by calling kernel directly"""
try:
from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
except ImportError:
self.skipTest("nunchaku package not available")
device = torch.device("cuda")
in_features, out_features = 128, 256
group_size = 64
batch_size = 4
# Create AWQ quantized weight tensors
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
# Create layout params
layout_params = {
'wscales': wscales,
'wzeros': wzeros,
'group_size': group_size,
'orig_dtype': torch.bfloat16,
'is_weight': True,
}
weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
# Check that weight is a QuantizedTensor
self.assertIsInstance(weight, QuantizedTensor)
self.assertEqual(weight._layout_type, "AWQQuantLayout")
# Create 2D input
x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
# Call AWQ linear handler directly
from comfy.quant_ops import awq_linear
output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})
# Check output shape
self.assertEqual(output.shape, (batch_size, out_features))
self.assertEqual(output.dtype, torch.bfloat16)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()