commit c8794e1155 (parent a17cf1c387)
add svdquant int4 support, modify qwen model to support nunchaku style merged qkv

157  QUANTIZATION.md
@@ -124,6 +124,22 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
 | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
 |--------|---------------|--------------|----------------|-----------------|-------------|
 | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+| svdquant_int4 | int8 (packed 4-bit) | - | - | - | - |
+| svdquant_nvfp4 | int8 (packed 4-bit) | - | - | - | - |
+| awq_int4 | int32 (packed 4-bit) | - | - | - | - |
+
+For SVDQuant formats, additional parameters are stored:
+- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
+- **smooth_factor**: Smoothing factors for inputs (shape: in_features)
+- **smooth_factor_orig**: Original smoothing factors (shape: in_features)
+- **proj_down**: Low-rank down projection (shape: in_features, rank)
+- **proj_up**: Low-rank up projection (shape: out_features, rank)
+- **wtscale**: Global weight scale (nvfp4 only, scalar float)
+- **wcscales**: Channel-wise weight scales (nvfp4 only, shape: out_features)
+
+For AWQ format, the following parameters are stored:
+- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
+- **wzeros**: Weight zero points (shape: in_features // group_size, out_features)
+
 You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
 
@@ -139,9 +155,9 @@ Example:
   "_quantization_metadata": {
     "format_version": "1.0",
     "layers": {
-      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
-      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
-      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+      "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
+      "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
+      "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
     }
   }
 }
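The hunk above switches the per-layer metadata entries from bare format strings to small objects. As a reading aid (not part of the diff), a minimal sketch of how a loader could accept both spellings; the helper name is hypothetical:

```python
import json

def layer_quant_format(entry):
    """Return the format string for one entry of _quantization_metadata["layers"].

    Older checkpoints store a bare string ("float8_e4m3fn"); newer ones store a
    dict such as {"format": "svdquant_int4"}. Both are handled here.
    """
    if isinstance(entry, str):
        return entry
    return entry["format"]

meta = json.loads('{"layers": {"model.layers.0.mlp.up_proj": {"format": "svdquant_int4"}}}')
print(layer_quant_format(meta["layers"]["model.layers.0.mlp.up_proj"]))  # svdquant_int4
```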
@@ -165,4 +181,137 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
 3. **Compute scales**: Derive `input_scale` from collected statistics
 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
 
 The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
+
+## SVDQuant
+
+SVDQuant is an advanced 4-bit quantization scheme that decomposes linear operations using low-rank factorization combined with residual quantization:
+
+```
+X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
+```
+
+Where:
+- `proj_down`, `proj_up`: Low-rank factorization matrices of the original weights
+- `R`: Residual weights (quantized to 4-bit)
+- `quantize()`: 4-bit quantization with smoothing factors
+
+### Key Features
+
+1. **Asymmetric Quantization**: Unlike FP8 where both weights and activations are quantized offline or use the same quantization scheme, SVDQuant:
+   - Quantizes weights offline with multiple parameters stored in the checkpoint
+   - Quantizes activations on-the-fly during forward pass using smoothing factors
+
+2. **Two Precision Modes**:
+   - `svdquant_int4`: 4-bit integer quantization with group_size=64
+   - `svdquant_nvfp4`: 4-bit floating-point (NVIDIA FP4) with group_size=16, includes additional channel-wise scales
+
+3. **Low-Rank Optimization**: Separates the easy-to-approximate low-rank component from the hard-to-quantize residual, improving accuracy.
+
+### Implementation
+
+SVDQuant requires the `nunchaku` library for optimized CUDA kernels:
+
+```bash
+pip install nunchaku
+```
+
+The implementation uses two main operations:
+- `svdq_quantize_w4a4_act_fuse_lora_cuda`: Quantizes activations and computes low-rank hidden states
+- `svdq_gemm_w4a4_cuda`: Performs the quantized GEMM with low-rank residual addition
+
+### Checkpoint Format
+
+SVDQuant checkpoints contain the standard weight tensor (packed 4-bit residuals in int8) plus additional parameters per quantized layer:
+
+```python
+{
+    "layer_name.weight": tensor,              # Packed 4-bit residual weights (out_features, in_features // 2)
+    "layer_name.wscales": tensor,             # Weight scales (in_features // group_size, out_features)
+    "layer_name.smooth_factor": tensor,       # Smoothing factors (in_features,)
+    "layer_name.smooth_factor_orig": tensor,  # Original smoothing factors (in_features,)
+    "layer_name.proj_down": tensor,           # Low-rank down projection (in_features, rank)
+    "layer_name.proj_up": tensor,             # Low-rank up projection (out_features, rank)
+
+    # For nvfp4 only:
+    "layer_name.wtscale": float,              # Global weight scale
+    "layer_name.wcscales": tensor,            # Channel-wise scales (out_features,)
+}
+```
+
+The quantization metadata specifies which layers use SVDQuant:
+
+```json
+{
+    "_quantization_metadata": {
+        "format_version": "1.0",
+        "layers": {
+            "model.layers.0.mlp.up_proj": {"format": "svdquant_int4"},
+            "model.layers.0.mlp.down_proj": {"format": "svdquant_int4"}
+        }
+    }
+}
+```
+
+## AWQ
+
+AWQ (Activation-aware Weight Quantization) is a 4-bit weight quantization scheme that keeps activations in 16-bit precision (W4A16):
+
+```
+Y = X @ W_quantized
+```
+
+Where:
+- `X`: 16-bit activations (float16/bfloat16)
+- `W_quantized`: 4-bit quantized weights with per-group scales and zero points
+
+### Key Features
+
+1. **W4A16 Quantization**:
+   - Quantizes weights to 4-bit while keeping activations in 16-bit
+   - Uses per-group quantization with configurable group size (typically 64)
+   - Stores zero points for asymmetric quantization
+
+2. **Activation-Aware**:
+   - Quantization is calibrated based on activation statistics
+   - Protects salient weights that are important for accuracy
+
+3. **Hardware Efficient**:
+   - Optimized for GPU inference
+   - Significantly reduces memory footprint
+   - Increases throughput with specialized kernels
+
+### Implementation
+
+AWQ requires the `nunchaku` library for optimized CUDA kernels:
+
+```bash
+pip install nunchaku
+```
+
+The implementation uses the `awq_gemv_w4a16_cuda` kernel for efficient W4A16 matrix multiplication.
+
+### Checkpoint Format
+
+AWQ checkpoints contain the standard weight tensor (packed 4-bit weights in int32) plus additional parameters per quantized layer:
+
+```python
+{
+    "layer_name.weight": tensor,   # Packed 4-bit weights (out_features // 4, in_features // 2)
+    "layer_name.wscales": tensor,  # Weight scales (in_features // group_size, out_features)
+    "layer_name.wzeros": tensor,   # Zero points (in_features // group_size, out_features)
+}
+```
+
+The quantization metadata specifies which layers use AWQ:
+
+```json
+{
+    "_quantization_metadata": {
+        "format_version": "1.0",
+        "layers": {
+            "model.layers.0.mlp.up_proj": {"format": "awq_int4"},
+            "model.layers.0.mlp.down_proj": {"format": "awq_int4"}
+        }
+    }
+}
+```
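For intuition, the SVDQuant decomposition documented above can be checked numerically in plain PyTorch. This is an illustrative sketch, not the nunchaku kernels: it truncates an SVD to get `proj_down`/`proj_up`, fake-quantizes only the residual per group to int4, and omits the activation quantization and smoothing factors of the real scheme. The helper name is made up for the example.

```python
import torch

def fake_quant_int4_per_group(w, group_size=64):
    """Symmetric per-group int4 fake-quantization along the last dim (illustration only)."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    scale = (g.abs().amax(dim=-1, keepdim=True) / 7.0).clamp_min(1e-8)  # map group max to 7 (int4 range [-8, 7])
    q = torch.clamp(torch.round(g / scale), -8, 7)
    return (q * scale).reshape(out_f, in_f)

torch.manual_seed(0)
in_f, out_f, rank = 128, 256, 32
x = torch.randn(4, in_f)
w = torch.randn(out_f, in_f)

# Low-rank part: W ~ U S V^T truncated to `rank`.
u, s, vh = torch.linalg.svd(w, full_matrices=False)
proj_down = vh[:rank].T * s[:rank]        # (in_f, rank)
proj_up = u[:, :rank]                     # (out_f, rank)
low_rank = proj_up @ proj_down.T          # (out_f, in_f)

residual = w - low_rank                   # hard-to-approximate remainder, quantized to 4-bit
y_ref = x @ w.T
y_svdq = x @ proj_down @ proj_up.T + x @ fake_quant_int4_per_group(residual).T
print((y_ref - y_svdq).abs().max())       # small: the low-rank path absorbs most of the error
```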
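Similarly, the AWQ section above can be illustrated with a per-group asymmetric 4-bit quantize/dequantize round trip. This sketch mirrors the roles of `wscales` and `wzeros` but does not use the packed int32 layout or the `awq_gemv_w4a16_cuda` kernel; it is an assumption-labeled reference, not the library implementation.

```python
import torch

def awq_fake_quant(w, group_size=64):
    """Asymmetric per-group 4-bit fake-quantization of a weight (out_features, in_features)."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    w_min = g.amin(dim=-1, keepdim=True)
    w_max = g.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp_min(1e-8) / 15.0          # 4-bit levels 0..15
    zero = torch.round(-w_min / scale)                      # per-group zero point (wzeros role)
    q = torch.clamp(torch.round(g / scale) + zero, 0, 15)   # integers that would be packed into int32 on disk
    deq = (q - zero) * scale                                # roughly what a W4A16 kernel reconstructs on the fly
    return deq.reshape(out_f, in_f)

torch.manual_seed(0)
w = torch.randn(256, 128)
x = torch.randn(4, 128, dtype=torch.float16)
y = x @ awq_fake_quant(w).to(torch.float16).T               # activations stay 16-bit (W4A16)
print((x @ w.to(torch.float16).T - y).abs().max())
```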
@@ -4,7 +4,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from einops import repeat
-
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
@@ -12,8 +11,9 @@ import comfy.ldm.common_dit
 import comfy.patcher_extension
 from comfy.ldm.flux.math import apply_rope1
 
+
 class GELU(nn.Module):
-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device)
         self.approximate = approximate
@@ -33,7 +33,9 @@ class FeedForward(nn.Module):
         dropout: float = 0.0,
         inner_dim=None,
         bias: bool = True,
-        dtype=None, device=None, operations=None
+        dtype=None, device=None, operations=None,
+        svdquant_format=False,
+        **kwargs,
     ):
         super().__init__()
         if inner_dim is None:
@@ -41,7 +43,7 @@ class FeedForward(nn.Module):
         dim_out = dim_out if dim_out is not None else dim
 
         self.net = nn.ModuleList([])
-        self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations))
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations, **kwargs))
         self.net.append(nn.Dropout(dropout))
         self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device))
 
@@ -92,7 +94,9 @@ class Attention(nn.Module):
         out_context_dim: int = None,
         dtype=None,
         device=None,
-        operations=None
+        operations=None,
+        svdquant_format=False,
+        **kwargs,
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@@ -109,21 +113,30 @@ class Attention(nn.Module):
         self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
         self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
 
+        self.svdquant_format = svdquant_format
+
         # Image stream projections
-        self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
-        self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
-        self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+        if self.svdquant_format:  # svdq merged qkv for better perf
+            self.to_qkv = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
+        else:
+            self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+            self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+            self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
 
         # Text stream projections
-        self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
-        self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
-        self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+        if self.svdquant_format:
+            self.add_qkv_proj = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device)
+        else:
+            self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+            self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+            self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
 
         # Output projections
         self.to_out = nn.ModuleList([
             operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device),
             nn.Dropout(dropout)
         ])
 
         self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
 
     def forward(
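The merged `to_qkv` projection is simply the three separate projections stacked along the output dimension, which is how a Nunchaku-style merged-QKV checkpoint relates to separate `to_q`/`to_k`/`to_v` weights (ignoring quantization). A hedged sketch with plain `nn.Linear` instead of the `operations` wrappers used above:

```python
import torch
import torch.nn as nn

query_dim, inner_dim, inner_kv_dim = 64, 64, 64
to_q = nn.Linear(query_dim, inner_dim)
to_k = nn.Linear(query_dim, inner_kv_dim)
to_v = nn.Linear(query_dim, inner_kv_dim)

# Merged projection: weights and biases concatenated along the output axis.
to_qkv = nn.Linear(query_dim, inner_dim + 2 * inner_kv_dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(2, 10, query_dim)
q, k, v = to_qkv(x).split([inner_dim, inner_kv_dim, inner_kv_dim], dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-6) and torch.allclose(k, to_k(x), atol=1e-6)
```

Note that the forward pass in the next hunk splits the merged output with `chunk(3, dim=-1)`, which assumes `inner_dim == inner_kv_dim`; `split` is the general form.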
@@ -140,29 +153,64 @@
         seq_txt = encoder_hidden_states.shape[1]
 
         # Project and reshape to BHND format (batch, heads, seq, dim)
-        img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
-        img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
-        img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
-
-        txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
-        txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
-        txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
+        if self.svdquant_format:
+            img_qkv = self.to_qkv(hidden_states)
+            img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
+            # Reshape for multi-head attention to [B, L, H, D]
+            img_query = img_query.unflatten(-1, (self.heads, -1))  # [B, L, H, D]
+            img_key = img_key.unflatten(-1, (self.heads, -1))
+            img_value = img_value.unflatten(-1, (self.heads, -1))
+        else:
+            img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+            img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+            img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
+
+        if self.svdquant_format:
+            txt_qkv = self.add_qkv_proj(encoder_hidden_states)
+            txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
+            # Reshape for multi-head attention to [B, L, H, D]
+            txt_query = txt_query.unflatten(-1, (self.heads, -1))
+            txt_key = txt_key.unflatten(-1, (self.heads, -1))
+            txt_value = txt_value.unflatten(-1, (self.heads, -1))
+        else:
+            txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+            txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+            txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
 
         img_query = self.norm_q(img_query)
         img_key = self.norm_k(img_key)
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
-        joint_query = torch.cat([txt_query, img_query], dim=2)
-        joint_key = torch.cat([txt_key, img_key], dim=2)
-        joint_value = torch.cat([txt_value, img_value], dim=2)
-
-        joint_query = apply_rope1(joint_query, image_rotary_emb)
-        joint_key = apply_rope1(joint_key, image_rotary_emb)
-
-        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
-                                                         attention_mask, transformer_options=transformer_options,
-                                                         skip_reshape=True)
+        if self.svdquant_format:
+            # Concatenate image and text streams for joint attention
+            joint_query = torch.cat([txt_query, img_query], dim=1)
+            joint_key = torch.cat([txt_key, img_key], dim=1)
+            joint_value = torch.cat([txt_value, img_value], dim=1)
+
+            # Apply rotary embeddings to concatenated tensors
+            joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
+            joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+
+            # Flatten to [B, L, H*D] for attention
+            joint_query = joint_query.flatten(start_dim=2)
+            joint_key = joint_key.flatten(start_dim=2)
+            joint_value = joint_value.flatten(start_dim=2)
+
+            joint_hidden_states = optimized_attention_masked(
+                joint_query, joint_key, joint_value, self.heads, attention_mask
+            )
+        else:
+            joint_query = torch.cat([txt_query, img_query], dim=2)
+            joint_key = torch.cat([txt_key, img_key], dim=2)
+            joint_value = torch.cat([txt_value, img_value], dim=2)
+
+            joint_query = apply_rope1(joint_query, image_rotary_emb)
+            joint_key = apply_rope1(joint_key, image_rotary_emb)
+
+            joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
+                                                             attention_mask, transformer_options=transformer_options,
+                                                             skip_reshape=True)
 
         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
         img_attn_output = joint_hidden_states[:, seq_txt:, :]
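Why the concatenation axis changes between the two branches: in the original path q/k/v are already in [B, H, L, D], so the joint sequence is built on dim=2, while the merged-QKV path keeps [B, L, H, D] until the final flatten, so the sequence axis is dim=1. A small shape-only illustration with hypothetical sizes:

```python
import torch

B, H, L_txt, L_img, D = 2, 8, 7, 13, 64

# Original path: tensors are [B, H, L, D]; the sequence axis is dim=2.
q_txt_bhld = torch.randn(B, H, L_txt, D)
q_img_bhld = torch.randn(B, H, L_img, D)
joint_bhld = torch.cat([q_txt_bhld, q_img_bhld], dim=2)
assert joint_bhld.shape == (B, H, L_txt + L_img, D)

# Merged-QKV path: tensors stay [B, L, H, D]; the sequence axis is dim=1.
q_txt_blhd = q_txt_bhld.transpose(1, 2)
q_img_blhd = q_img_bhld.transpose(1, 2)
joint_blhd = torch.cat([q_txt_blhd, q_img_blhd], dim=1)
assert joint_blhd.shape == (B, L_txt + L_img, H, D)

# After flatten(start_dim=2) the attention input is [B, L, H*D]; both layouts carry the same content.
assert torch.allclose(joint_blhd.flatten(start_dim=2),
                      joint_bhld.transpose(1, 2).flatten(start_dim=2))
```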
@@ -183,28 +231,38 @@ class QwenImageTransformerBlock(nn.Module):
         eps: float = 1e-6,
         dtype=None,
         device=None,
-        operations=None
+        operations=None,
+        scale_shift: float = None,
+        svdquant_format=False,
+        **kwargs,
     ):
         super().__init__()
         self.dim = dim
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
+        self.svdquant_format = svdquant_format
+        # For svdquant, scale_shift should be 0 as the shift is fused into weights
+        if scale_shift is None:
+            scale_shift = 0.0 if self.svdquant_format else 1.0
+        self.scale_shift = scale_shift
 
         self.img_mod = nn.Sequential(
             nn.SiLU(),
             operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
         )
 
         self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
         self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
-        self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+        self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
 
         self.txt_mod = nn.Sequential(
             nn.SiLU(),
             operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
         )
 
         self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
         self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device)
-        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations)
+        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs)
 
         self.attn = Attention(
             query_dim=dim,
@@ -216,11 +274,18 @@ class QwenImageTransformerBlock(nn.Module):
             dtype=dtype,
             device=device,
             operations=operations,
+            svdquant_format=svdquant_format,
+            **kwargs,
         )
 
     def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if self.svdquant_format:
+            if self.scale_shift != 0:
+                scale.add_(self.scale_shift)
+            return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
         self,
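The two `_modulate` branches agree once the shift convention is accounted for: `torch.addcmul(shift, x, 1 + scale)` equals `x * (scale + 1) + shift`, so a checkpoint whose modulation weights already fold in the +1 (as the comment above says Nunchaku's do) uses `scale_shift = 0`, while the original behaviour corresponds to `scale_shift = 1`. A quick numerical check, illustrative only:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 16, 8)          # [B, L, dim]
shift = torch.randn(2, 8)
scale = torch.randn(2, 8)

# Original path: addcmul with an implicit +1 on the scale.
orig = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))

# Fused path: the +1 becomes a configurable scale_shift added to `scale`.
scale_shift = 1.0                   # 0.0 when the checkpoint already folds it into the weights
fused = x * (scale + scale_shift).unsqueeze(1) + shift.unsqueeze(1)

print(torch.allclose(orig, fused, atol=1e-6))  # True
```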
@@ -233,21 +298,42 @@ class QwenImageTransformerBlock(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
         txt_mod_params = self.txt_mod(temb)
+
+        # Nunchaku's mod_params layout is [B, dim*6] with different ordering
+        # Need to reshape from [B, dim*6] to correct layout
+        if self.svdquant_format:
+            img_mod_params = (
+                img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+            )
+            txt_mod_params = (
+                txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+            )
+
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
         img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
 
-        img_attn_output, txt_attn_output = self.attn(
-            hidden_states=img_modulated,
-            encoder_hidden_states=txt_modulated,
-            encoder_hidden_states_mask=encoder_hidden_states_mask,
-            image_rotary_emb=image_rotary_emb,
-            transformer_options=transformer_options,
-        )
+        if self.svdquant_format:
+            img_attn_output, txt_attn_output = self.attn(
+                hidden_states=img_modulated,
+                encoder_hidden_states=txt_modulated,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+        else:
+            img_attn_output, txt_attn_output = self.attn(
+                hidden_states=img_modulated,
+                encoder_hidden_states=txt_modulated,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )
         del img_modulated
         del txt_modulated
 
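What the view/transpose/reshape accomplishes: a Nunchaku-style fused modulation weight emits the six modulation values interleaved per channel ([B, dim*6]), while the rest of this block expects six contiguous [B, dim] chunks so that `chunk(2)` followed by `chunk(3)` yields shift/scale/gate pairs. A toy check of that reordering (sizes are hypothetical):

```python
import torch

B, dim = 1, 4
# Toy "interleaved" layout: for each channel c the 6 modulation values sit together,
# i.e. flat index c*6 + j; the value 10*j + c lets us track where each element lands.
interleaved = torch.stack([torch.tensor([10 * j + c for c in range(dim)]) for j in range(6)], dim=-1)  # (dim, 6)
mod_params = interleaved.reshape(B, dim * 6).float()

# The transform used in the block: (B, dim, 6) -> (B, 6, dim) -> (B, 6*dim)
fixed = mod_params.view(B, -1, 6).transpose(1, 2).reshape(B, -1)

# Now the layout is chunk-major: six contiguous [B, dim] blocks, matching chunk(2)/chunk(3) downstream.
chunks = fixed.chunk(6, dim=-1)
print(chunks[0])  # tensor([[0., 1., 2., 3.]]) -> the j=0 value for every channel
```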
@@ -258,6 +344,8 @@ class QwenImageTransformerBlock(nn.Module):
         del img_gate1
         del txt_gate1
+
+
 
         img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
         hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
 
@@ -307,7 +395,15 @@ class QwenImageTransformer2DModel(nn.Module):
         dtype=None,
         device=None,
         operations=None,
+        scale_shift: float = None,
+        svdquant_format=False,
+        **kwargs,
     ):
+        # For svdquant, scale_shift should be 0 as the shift is fused into weights
+        self.svdquant_format = svdquant_format
+
+        if scale_shift is None:
+            scale_shift = 0.0 if self.svdquant_format else 1.0
         super().__init__()
         self.dtype = dtype
         self.patch_size = patch_size
@@ -336,7 +432,10 @@ class QwenImageTransformer2DModel(nn.Module):
                 attention_head_dim=attention_head_dim,
                 dtype=dtype,
                 device=device,
-                operations=operations
+                operations=operations,
+                scale_shift=scale_shift,
+                svdquant_format=svdquant_format,
+                **kwargs
             )
             for _ in range(num_layers)
         ])
@@ -384,10 +483,12 @@ class QwenImageTransformer2DModel(nn.Module):
         control=None,
         **kwargs
     ):
+        #from safetensors import safe_open
+        #with safe_open("/root/nck_x.safetensors", framework="pt", device="cuda") as f:
+        #    x = f.get_tensor("nck_x")
         timestep = timesteps
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask
 
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
@@ -419,7 +520,10 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+        if self.svdquant_format:
+            image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        else:
+            image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
         del ids, txt_ids, img_ids
 
         hidden_states = self.img_in(hidden_states)
@@ -441,6 +545,9 @@ class QwenImageTransformer2DModel(nn.Module):
 
         transformer_options["total_blocks"] = len(self.transformer_blocks)
         transformer_options["block_type"] = "double"
+
+
+
         for i, block in enumerate(self.transformer_blocks):
             transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
@@ -623,6 +623,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "qwen_image"
         dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+
+        # Add SVDQuant linear support
+        if '{}transformer_blocks.0.attn.add_qkv_proj.weight'.format(key_prefix) in state_dict_keys:
+            # try import nunchaku:
+            try:
+                from nunchaku.models.linear import SVDQW4A4Linear
+            except ImportError:
+                raise ImportError(
+                    "SVDQuant requires the nunchaku library. "
+                    "Please follow the instructions in https://nunchaku.tech/docs/nunchaku/installation/installation.html to install nunchaku"
+                )
+
+            dit_config["svdquant_format"] = True
+
+            if metadata is not None and 'config' in metadata:
+                if 'quantization_config' in metadata:
+                    import json
+                    metadata_quantization_config = json.loads(metadata['quantization_config'])
+                    if 'weight' in metadata_quantization_config:
+                        if metadata_quantization_config["weight"]["dtype"] == "fp4_e2m1_all":
+                            if metadata_quantization_config["weight"]["group_size"] == 16:
+                                dit_config['precision'] = "nvfp4"
+                        elif metadata_quantization_config["weight"]["dtype"] == "int4":
+                            dit_config['precision'] = "int4"
         return dit_config
 
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
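Detection in the hunk above hinges on two signals: the presence of a merged `add_qkv_proj` weight key, and the checkpoint's `quantization_config` metadata, whose `weight.dtype` / `weight.group_size` select `nvfp4` versus `int4`. A standalone sketch of that decision logic with a made-up metadata dict (the real values come from the safetensors header; the function name is hypothetical):

```python
import json

def detect_svdquant_precision(state_dict_keys, metadata, key_prefix=""):
    """Offline mirror of the detection logic above (illustrative, not the ComfyUI API)."""
    cfg = {}
    if f"{key_prefix}transformer_blocks.0.attn.add_qkv_proj.weight" not in state_dict_keys:
        return cfg                                    # not a merged-qkv / SVDQuant checkpoint
    cfg["svdquant_format"] = True
    qc = json.loads(metadata.get("quantization_config", "{}")).get("weight", {})
    if qc.get("dtype") == "fp4_e2m1_all" and qc.get("group_size") == 16:
        cfg["precision"] = "nvfp4"
    elif qc.get("dtype") == "int4":
        cfg["precision"] = "int4"
    return cfg

# Hypothetical metadata resembling a Nunchaku nvfp4 export:
meta = {"quantization_config": json.dumps({"weight": {"dtype": "fp4_e2m1_all", "group_size": 16}})}
keys = {"transformer_blocks.0.attn.add_qkv_proj.weight"}
print(detect_svdquant_precision(keys, meta))  # {'svdquant_format': True, 'precision': 'nvfp4'}
```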
67  comfy/ops.py

@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import contextlib
+from comfy.quant_ops import QuantizedTensor
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -582,11 +583,17 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                                   strict, missing_keys, unexpected_keys, error_msgs):
 
             device = self.factory_kwargs["device"]
+            if device is None and self.bias is not None:
+                device = self.bias.device
+
             layer_name = prefix.rstrip('.')
             weight_key = f"{prefix}weight"
             weight = state_dict.pop(weight_key, None)
             if weight is None:
                 raise ValueError(f"Missing weight for layer {layer_name}")
 
+            if device is None:
+                device = weight.device
+
             manually_loaded_keys = [weight_key]
 
@@ -600,27 +607,58 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
             qconfig = QUANT_ALGOS[quant_format]
             self.layout_type = qconfig["comfy_tensor_layout"]
 
-            weight_scale_key = f"{prefix}weight_scale"
+            # Build layout_params - start with basic parameters
             layout_params = {
-                'scale': state_dict.pop(weight_scale_key, None),
                 'orig_dtype': MixedPrecisionOps._compute_dtype,
-                'block_size': qconfig.get("group_size", None),
+                'is_weight': True,  # Mark this as a weight tensor
             }
-            if layout_params['scale'] is not None:
+
+            # Add group_size and precision if present in qconfig
+            if 'group_size' in qconfig:
+                layout_params['group_size'] = qconfig['group_size']
+            if 'precision' in qconfig:
+                layout_params['precision'] = qconfig['precision']
+
+            # Handle weight_scale
+            weight_scale_key = f"{prefix}weight_scale"
+            weight_scale = state_dict.pop(weight_scale_key, None)
+            if weight_scale is not None:
+                layout_params['scale'] = weight_scale
                 manually_loaded_keys.append(weight_scale_key)
 
+            # custom_layer_params_keys are loaded into layout_params from state_dict
+            if 'custom_layer_params_keys' in qconfig:
+                for param_name in qconfig['custom_layer_params_keys']:
+                    param_key = f"{prefix}{param_name}"
+                    param_value = state_dict.pop(param_key, None)
+                    if param_value is not None:
+                        layout_params[param_name] = param_value.to(device=device).contiguous()
+                        manually_loaded_keys.append(param_key)
+                    else:
+                        logging.warning(f"Missing custom parameter {param_name} for layer {layer_name}")
+
+            # parameters are loaded into module attributes from state_dict
+            for param_name in qconfig["parameters"]:
+                if param_name in layout_params:
+                    continue  # Already loaded via custom_layer_params_keys or weight_scale
+
+                param_key = f"{prefix}{param_name}"
+                param_value = state_dict.pop(param_key, None)
+                if param_value is not None:
+                    # For standard parameters, store as module attributes
+                    setattr(self, param_name, torch.nn.Parameter(param_value.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            # Create the quantized weight tensor
+            quantized_weight = QuantizedTensor(weight.to(device=device),
+                                               self.layout_type, layout_params)
+
+            self.weight_prefix = prefix
             self.weight = torch.nn.Parameter(
-                QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                quantized_weight,
                 requires_grad=False
             )
+            self.weight.requires_grad = False
-            for param_name in qconfig["parameters"]:
-                param_key = f"{prefix}{param_name}"
-                _v = state_dict.pop(param_key, None)
-                if _v is None:
-                    continue
-                setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
-                manually_loaded_keys.append(param_key)
 
             super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
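For a concrete picture of what this loader assembles, here is roughly what `layout_params` could look like after loading a hypothetical `svdquant_int4` layer, based on the QUANT_ALGOS entry added further down; the sizes are illustrative, not captured from a real checkpoint:

```python
import torch

in_f, out_f, rank, group = 3072, 3072, 32, 64                    # hypothetical layer sizes
layout_params = {
    "orig_dtype": torch.bfloat16,                                 # MixedPrecisionOps._compute_dtype
    "is_weight": True,
    "group_size": group,                                          # from QUANT_ALGOS["svdquant_int4"]
    "precision": "int4",
    "wscales": torch.empty(in_f // group, out_f),                 # weight quantization scales
    "smooth_factor": torch.empty(in_f),                           # activation smoothing factors
    "smooth_factor_orig": torch.empty(in_f),
    "proj_down": torch.empty(in_f, rank),                         # low-rank down projection
    "proj_up": torch.empty(out_f, rank),                          # low-rank up projection
}
packed_weight = torch.empty(out_f, in_f // 2, dtype=torch.int8)   # packed 4-bit residual
# The loader wraps this as QuantizedTensor(packed_weight, "SVDQuantLayout", layout_params).
```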
@@ -646,6 +684,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                     getattr(self, 'input_scale', None) is not None and
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+
             return self._forward(input, self.weight, self.bias)
 
         def convert_weight(self, weight, inplace=False, **kwargs):
@@ -696,4 +735,4 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
 
     return manual_cast
comfy/quant_ops.py

@@ -46,13 +46,28 @@ def register_generic_util(torch_op):
 
 
 def _get_layout_from_args(args):
+    def _extract_layout(obj):
+        if isinstance(obj, QuantizedTensor):
+            return obj._layout_type
+        # For torch.nn.Parameter wrapping QuantizedTensor, check the data attribute
+        if isinstance(obj, torch.nn.Parameter):
+            if isinstance(obj.data, QuantizedTensor):
+                return obj.data._layout_type
+            if hasattr(obj.data, "_layout_type"):
+                return getattr(obj.data, "_layout_type", None)
+        if hasattr(obj, "_layout_type"):
+            return getattr(obj, "_layout_type", None)
+        return None
+
     for arg in args:
-        if isinstance(arg, QuantizedTensor):
-            return arg._layout_type
-        elif isinstance(arg, (list, tuple)):
+        layout = _extract_layout(arg)
+        if layout is not None:
+            return layout
+        if isinstance(arg, (list, tuple)):
             for item in arg:
-                if isinstance(item, QuantizedTensor):
-                    return item._layout_type
+                layout = _extract_layout(item)
+                if layout is not None:
+                    return layout
     return None
 
 
@@ -438,6 +453,46 @@ QUANT_ALGOS = {
         "parameters": {"weight_scale", "input_scale"},
         "comfy_tensor_layout": "TensorCoreFP8Layout",
     },
+    "svdquant_int4": {
+        "storage_t": torch.int8,  # Packed 4-bit stored in int8
+        "parameters": {
+            "wscales",
+            "smooth_factor",
+            "smooth_factor_orig",
+            "proj_down",
+            "proj_up",
+        },
+        "custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up"],
+        "comfy_tensor_layout": "SVDQuantLayout",
+        "group_size": 64,
+        "precision": "int4",
+    },
+    "svdquant_nvfp4": {
+        "storage_t": torch.int8,  # Packed 4-bit stored in int8
+        "parameters": {
+            "wscales",
+            "smooth_factor",
+            "smooth_factor_orig",
+            "proj_down",
+            "proj_up",
+            "wtscale",
+            "wcscales",
+        },
+        "custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up", "wtscale", "wcscales"],
+        "comfy_tensor_layout": "SVDQuantLayout",
+        "group_size": 16,
+        "precision": "nvfp4",
+    },
+    "awq_int4": {
+        "storage_t": torch.int32,  # Packed 4-bit stored in int32
+        "parameters": {
+            "wscales",
+            "wzeros",
+        },
+        "custom_layer_params_keys": ["wscales", "wzeros"],
+        "comfy_tensor_layout": "AWQQuantLayout",
+        "group_size": 64,
+    },
 }
 
 LAYOUTS = {
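These entries drive the state-dict loader in `comfy/ops.py`: the per-layer format string from the quantization metadata is looked up here, its `custom_layer_params_keys` are pulled into the tensor's layout params, and anything else in `parameters` becomes a module attribute (for example `weight_scale`/`input_scale` in the fp8 entry). A minimal illustration of the lookup, assuming ComfyUI is importable as QUANTIZATION.md describes:

```python
from comfy.quant_ops import QUANT_ALGOS

qconfig = QUANT_ALGOS["svdquant_int4"]
print(qconfig["comfy_tensor_layout"])                # SVDQuantLayout
print(qconfig["group_size"], qconfig["precision"])   # 64 int4

# Keys listed here are loaded from the checkpoint into the QuantizedTensor's layout_params.
for name in qconfig["custom_layer_params_keys"]:
    print("layout param:", name)
```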
@ -571,3 +626,439 @@ def fp8_func(func, args, kwargs):
|
|||||||
ar[0] = plain_input
|
ar[0] = plain_input
|
||||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# SVDQuant Layout + Operation Handlers
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class SVDQuantLayout(QuantizedLayout):
|
||||||
|
"""
|
||||||
|
SVDQuant W4A4 quantization layout.
|
||||||
|
|
||||||
|
SVDQuant decomposes linear operations as:
|
||||||
|
X*W = X * proj_up * proj_down + quantize(X) * quantize(R)
|
||||||
|
|
||||||
|
Where:
|
||||||
|
- proj_up, proj_down: Low-rank factorization of weights
|
||||||
|
- R: Residual weights (quantized to 4-bit)
|
||||||
|
- quantize(): 4-bit quantization with smoothing factors
|
||||||
|
|
||||||
|
Storage format:
|
||||||
|
For weights (is_weight=True):
|
||||||
|
- qdata: Packed quantized residual weights (out_features, in_features // 2), int8
|
||||||
|
- wscales: Weight quantization scales
|
||||||
|
- smooth_factor: Smoothing factors for inputs
|
||||||
|
- proj_down: Low-rank down projection
|
||||||
|
- proj_up: Low-rank up projection
|
||||||
|
- group_size: Quantization group size (64 for int4, 16 for nvfp4)
|
||||||
|
- precision: 'int4' or 'nvfp4'
|
||||||
|
- rank: SVD rank
|
||||||
|
- wtscale: Global weight scale (nvfp4 only)
|
||||||
|
- wcscales: Channel-wise weight scales (nvfp4 only)
|
||||||
|
- act_unsigned: Whether activations are unsigned (int4 only)
|
||||||
|
- orig_dtype: Original dtype before quantization
|
||||||
|
|
||||||
|
For activations (is_weight=False):
|
||||||
|
- qdata: Original activation tensor (not quantized yet)
|
||||||
|
- orig_dtype: Original dtype
|
||||||
|
- is_weight: False marker
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, is_weight=True, **kwargs):
|
||||||
|
"""
|
||||||
|
For SVDQuant, we don't perform online quantization.
|
||||||
|
- Weights are pre-quantized offline and loaded from checkpoint
|
||||||
|
- Activations are stored as-is and quantized during forward pass
|
||||||
|
"""
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
|
||||||
|
if is_weight:
|
||||||
|
# This shouldn't be called for weights as they're loaded pre-quantized
|
||||||
|
raise NotImplementedError(
|
||||||
|
"SVDQuant weights should be loaded pre-quantized from checkpoint, "
|
||||||
|
"not quantized on-the-fly"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For activations, just store the tensor as-is
|
||||||
|
# It will be quantized during the linear operation
|
||||||
|
layout_params = {
|
||||||
|
'orig_dtype': orig_dtype,
|
||||||
|
'is_weight': False
|
||||||
|
}
|
||||||
|
return tensor, layout_params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dequantize(qdata, is_weight=True, orig_dtype=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Dequantization for SVDQuant.
|
||||||
|
- Activations: return as-is (not actually quantized)
|
||||||
|
- Weights: full dequantization not supported (would need to reconstruct from SVD + residual)
|
||||||
|
"""
|
||||||
|
if not is_weight:
|
||||||
|
# Activations aren't actually quantized, just return them
|
||||||
|
return qdata.to(orig_dtype) if orig_dtype else qdata
|
||||||
|
else:
|
||||||
|
# Full weight dequantization is complex and not typically needed
|
||||||
|
# Would require: proj_down @ proj_up.T + dequantize(qweight)
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Full dequantization of SVDQuant weights is not supported. "
|
||||||
|
"Use the quantized forward pass instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_plain_tensors(cls, qtensor):
|
||||||
|
"""Extract the raw tensors needed for SVDQuant computation."""
|
||||||
|
if qtensor._layout_params.get('is_weight', True):
|
||||||
|
# For weights, return all the necessary components
|
||||||
|
return {
|
||||||
|
'qweight': qtensor._qdata,
|
||||||
|
'wscales': qtensor._layout_params.get('wscales'),
|
||||||
|
'smooth_factor': qtensor._layout_params.get('smooth_factor'),
|
||||||
|
'proj_down': qtensor._layout_params.get('proj_down'),
|
||||||
|
'proj_up': qtensor._layout_params.get('proj_up'),
|
||||||
|
'group_size': qtensor._layout_params.get('group_size'),
|
||||||
|
'precision': qtensor._layout_params.get('precision', 'int4'),
|
||||||
|
'wtscale': qtensor._layout_params.get('wtscale'),
|
||||||
|
'wcscales': qtensor._layout_params.get('wcscales'),
|
||||||
|
'act_unsigned': qtensor._layout_params.get('act_unsigned', False),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# For activations, just return the tensor
|
||||||
|
return qtensor._qdata
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.addmm.default, "SVDQuantLayout")
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, "SVDQuantLayout")
|
||||||
|
def svdquant_linear(func, args, kwargs):
|
||||||
|
"""
|
||||||
|
SVDQuant linear operation handler.
|
||||||
|
|
||||||
|
Implements: X*W = X * proj_up * proj_down + quantize(X) * quantize(R)
|
||||||
|
|
||||||
|
Handles both aten.linear and aten.addmm (which linear decomposes into).
|
||||||
|
"""
|
||||||
|
# Handle both linear and addmm calling conventions
|
||||||
|
if func == torch.ops.aten.addmm.default:
|
||||||
|
# addmm(bias, input, weight.t()) -> out
|
||||||
|
bias = args[0] if len(args) > 0 else None
|
||||||
|
input_tensor = args[1] if len(args) > 1 else None
|
||||||
|
weight = args[2] if len(args) > 2 else None
|
||||||
|
# Weight comes transposed in addmm, but SVDQuant stores it non-transposed
|
||||||
|
# So we need to transpose it back
|
||||||
|
need_transpose = True
|
||||||
|
else:
|
||||||
|
# linear(input, weight, bias) -> out
|
||||||
|
input_tensor = args[0]
|
||||||
|
weight = args[1]
|
||||||
|
bias = args[2] if len(args) > 2 else None
|
||||||
|
need_transpose = False
|
||||||
|
|
||||||
|
# Unwrap Parameter if necessary
|
||||||
|
if isinstance(weight, torch.nn.Parameter):
|
||||||
|
weight = weight.data
|
||||||
|
|
||||||
|
# Check if weight is SVDQuant quantized
|
||||||
|
if not isinstance(weight, QuantizedTensor) or weight._layout_type != "SVDQuantLayout":
|
||||||
|
# Fallback to standard linear
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight = weight.dequantize()
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
input_tensor = input_tensor.dequantize()
|
||||||
|
if func == torch.ops.aten.addmm.default:
|
||||||
|
return torch.addmm(bias, input_tensor, weight)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||||
|
|
||||||
|
# Extract weight parameters
|
||||||
|
weight_params = SVDQuantLayout.get_plain_tensors(weight)
|
||||||
|
qweight = weight_params['qweight']
|
||||||
|
wscales = weight_params['wscales']
|
||||||
|
smooth_factor = weight_params['smooth_factor']
|
||||||
|
proj_down = weight_params['proj_down']
|
||||||
|
proj_up = weight_params['proj_up']
|
||||||
|
group_size = weight_params['group_size']
|
||||||
|
precision = weight_params['precision']
|
||||||
|
wtscale = weight_params['wtscale']
|
||||||
|
wcscales = weight_params['wcscales']
|
||||||
|
act_unsigned = weight_params['act_unsigned']
|
||||||
|
|
||||||
|
# Get activation tensor (dequantize if it's a QuantizedTensor)
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
if input_tensor._layout_type == "SVDQuantLayout":
|
||||||
|
x = SVDQuantLayout.get_plain_tensors(input_tensor)
|
||||||
|
else:
|
||||||
|
x = input_tensor.dequantize()
|
||||||
|
else:
|
||||||
|
x = input_tensor
|
||||||
|
|
||||||
|
# Import nunchaku operations
|
||||||
|
try:
|
||||||
|
from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
|
||||||
|
from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"SVDQuant requires the nunchaku library. "
|
||||||
|
"Install it with: pip install nunchaku"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle batch dimensions
|
||||||
|
original_shape = x.shape
|
||||||
|
if len(original_shape) == 2:
|
||||||
|
batch_size, channels = original_shape
|
||||||
|
seq_len = 1
|
||||||
|
x = x.view(batch_size, seq_len, channels)
|
||||||
|
elif len(original_shape) == 3:
|
||||||
|
batch_size, seq_len, channels = original_shape
|
||||||
|
else:
|
||||||
|
raise ValueError(f"SVDQuant linear expects 2D or 3D input, got {len(original_shape)}D")
|
||||||
|
|
||||||
|
# Reshape to 2D for computation
|
||||||
|
x_2d = x.reshape(batch_size * seq_len, channels)
|
||||||
|
original_batch_size = x_2d.shape[0] # Track original size before padding
|
||||||
|
|
||||||
|
# Step 1: Quantize activations and compute low-rank hidden states
|
||||||
|
# Output: quantized_x, ascales, lora_act_out
|
||||||
|
quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
|
||||||
|
x_2d,
|
||||||
|
lora_down=proj_down,
|
||||||
|
smooth=smooth_factor,
|
||||||
|
fp4=(precision == "nvfp4"),
|
||||||
|
pad_size=256
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2: Compute quantized GEMM with low-rank residual
|
||||||
|
# Output shape: (N_padded, out_features) where N_padded may be larger due to padding
|
||||||
|
out_features = qweight.shape[0]
|
||||||
|
output = torch.empty(
|
||||||
|
quantized_x.shape[0],
|
||||||
|
out_features,
|
||||||
|
dtype=proj_up.dtype,
|
||||||
|
device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
svdq_gemm_w4a4_cuda(
|
||||||
|
act=quantized_x,
|
||||||
|
wgt=qweight,
|
||||||
|
out=output,
|
||||||
|
ascales=ascales,
|
||||||
|
wscales=wscales,
|
||||||
|
lora_act_in=lora_act_out,
|
||||||
|
lora_up=proj_up,
|
||||||
|
bias=bias,
|
||||||
|
fp4=(precision == "nvfp4"),
|
||||||
|
alpha=wtscale,
|
||||||
|
wcscales=wcscales,
|
||||||
|
act_unsigned=act_unsigned,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slice to remove padding and reshape back to original batch dimensions
|
||||||
|
output = output[:original_batch_size, :] # Remove padding
|
||||||
|
if len(original_shape) == 2:
|
||||||
|
output = output.view(batch_size, out_features)
|
||||||
|
else:
|
||||||
|
output = output.view(batch_size, seq_len, out_features)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
# AWQ Layout + Operation Handlers
# ==============================================================================

class AWQQuantLayout(QuantizedLayout):
    """
    AWQ W4A16 quantization layout.

    AWQ (Activation-aware Weight Quantization) quantizes weights to 4-bit
    while keeping activations in 16-bit precision (float16/bfloat16).

    Storage format:
    For weights (is_weight=True):
    - qdata: Packed quantized weights (out_features // 4, in_features // 2), int32
    - wscales: Weight quantization scales (in_features // group_size, out_features)
    - wzeros: Weight zero points (in_features // group_size, out_features)
    - group_size: Quantization group size (default 64)
    - orig_dtype: Original dtype before quantization

    For activations (is_weight=False):
    - qdata: Original activation tensor (not quantized)
    - orig_dtype: Original dtype
    - is_weight: False marker
    """

    @classmethod
    def quantize(cls, tensor, is_weight=True, **kwargs):
        """
        For AWQ, we don't perform online quantization.
        - Weights are pre-quantized offline and loaded from checkpoint
        - Activations remain in 16-bit precision
        """
        orig_dtype = tensor.dtype

        if is_weight:
            # This shouldn't be called for weights as they're loaded pre-quantized
            raise NotImplementedError(
                "AWQ weights should be loaded pre-quantized from checkpoint, "
                "not quantized on-the-fly"
            )
        else:
            # For activations, just store the tensor as-is
            layout_params = {
                'orig_dtype': orig_dtype,
                'is_weight': False
            }
            return tensor, layout_params

    @staticmethod
    def dequantize(qdata, is_weight=True, orig_dtype=None, wscales=None, wzeros=None, group_size=64, **kwargs):
        """
        Dequantization for AWQ.
        - Activations: return as-is (not quantized)
        - Weights: unpack and dequantize from 4-bit
        """
        if not is_weight:
            # Activations aren't quantized, just return them
            return qdata.to(orig_dtype) if orig_dtype else qdata
        else:
            # Dequantize 4-bit weights
            # qdata shape: (out_features // 4, in_features // 2), dtype int32
            # Output shape should be: (out_features, in_features)

            # This is a complex operation that requires unpacking 4-bit values
            # For now, we'll raise an error and rely on the quantized forward pass
            raise NotImplementedError(
                "Full dequantization of AWQ weights is not yet supported. "
                "Use the quantized forward pass instead."
            )

    @classmethod
    def get_plain_tensors(cls, qtensor):
        """Extract the raw tensors needed for AWQ computation."""
        if qtensor._layout_params.get('is_weight', True):
            # For weights, return all the necessary components
            return {
                'qweight': qtensor._qdata,
                'wscales': qtensor._layout_params.get('wscales'),
                'wzeros': qtensor._layout_params.get('wzeros'),
                'group_size': qtensor._layout_params.get('group_size', 64),
            }
        else:
            # For activations, just return the tensor
            return qtensor._qdata


@register_layout_op(torch.ops.aten.addmm.default, "AWQQuantLayout")
@register_layout_op(torch.ops.aten.linear.default, "AWQQuantLayout")
def awq_linear(func, args, kwargs):
    """
    AWQ linear operation handler.

    Implements W4A16 quantized linear using AWQ format.

    Handles both aten.linear and aten.addmm (which linear decomposes into).
    """
    # Handle both linear and addmm calling conventions
    if func == torch.ops.aten.addmm.default:
        # addmm(bias, input, weight.t()) -> out
        bias = args[0] if len(args) > 0 else None
        input_tensor = args[1] if len(args) > 1 else None
        weight = args[2] if len(args) > 2 else None
    else:
        # linear(input, weight, bias) -> out
        input_tensor = args[0]
        weight = args[1]
        bias = args[2] if len(args) > 2 else None

    # Unwrap Parameter if necessary
    if isinstance(weight, torch.nn.Parameter):
        weight = weight.data

    # Check if weight is AWQ quantized
    if not isinstance(weight, QuantizedTensor) or weight._layout_type != "AWQQuantLayout":
        # Fallback to standard linear
        if isinstance(weight, QuantizedTensor):
            weight = weight.dequantize()
        if isinstance(input_tensor, QuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if func == torch.ops.aten.addmm.default:
            return torch.addmm(bias, input_tensor, weight)
        else:
            return torch.nn.functional.linear(input_tensor, weight, bias)

    # Extract weight parameters
    weight_params = AWQQuantLayout.get_plain_tensors(weight)
    qweight = weight_params['qweight']
    wscales = weight_params['wscales']
    wzeros = weight_params['wzeros']
    group_size = weight_params['group_size']

    # Get activation tensor (dequantize if it's a QuantizedTensor)
    if isinstance(input_tensor, QuantizedTensor):
        if input_tensor._layout_type == "AWQQuantLayout":
            x = AWQQuantLayout.get_plain_tensors(input_tensor)
        else:
            x = input_tensor.dequantize()
    else:
        x = input_tensor

    # Import nunchaku AWQ operation
    try:
        from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
    except ImportError:
        raise ImportError(
            "AWQ requires the nunchaku library. "
            "Install it with: pip install nunchaku"
        )

    # Calculate output dimensions from packed weight shape
    # qweight shape: (out_features // 4, in_features // 2)
    out_features = qweight.shape[0] * 4
    in_features = qweight.shape[1] * 2

    # Handle batch dimensions - preserve original shape
    # Important: nunchaku expects 2D input only, so we reshape 3D to 2D
    original_shape = x.shape
    if len(original_shape) == 2:
        # (batch_size, in_features)
        batch_size = original_shape[0]
        x_2d = x
    #elif len(original_shape) == 3:
    #    # (batch_size, seq_len, in_features) -> (batch_size * seq_len, in_features)
    #    batch_size, seq_len, _ = original_shape
    #    x_2d = x.reshape(batch_size * seq_len, in_features)
    else:
        raise ValueError(f"AWQ linear expects 2D input, got {len(original_shape)}D")

    # Ensure input is contiguous (required by CUDA kernel)
    # Only create a contiguous copy if absolutely necessary
    #if not x_2d.is_contiguous():
    #    x_2d = x_2d.contiguous()

    output = awq_gemv_w4a16_cuda(
        in_feats=x_2d,
        kernel=qweight,
        scaling_factors=wscales,
        zeros=wzeros,
        m=x_2d.shape[0],
        n=out_features,
        k=in_features,
        group_size=group_size,
    )

    # Add bias if present
    if bias is not None:
        view_shape = [1] * (output.ndim - 1) + [-1]
        output = output + bias.view(view_shape)

    # Reshape back to original batch dimensions
    #if len(original_shape) == 3:
    #    output = output.view(batch_size, seq_len, out_features)

    return output


LAYOUTS["SVDQuantLayout"] = SVDQuantLayout
LAYOUTS["AWQQuantLayout"] = AWQQuantLayout
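The unit tests further down in this commit exercise this handler by wrapping pre-quantized AWQ tensors in a QuantizedTensor and calling awq_linear directly. A minimal sketch of that call path (not part of the diff), assuming a CUDA device with the nunchaku kernels installed and made-up layer sizes:

import torch
from comfy.quant_ops import QuantizedTensor, awq_linear

# Made-up sizes; shapes follow the AWQQuantLayout docstring above.
in_features, out_features, group_size = 128, 256, 64
qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device="cuda")
layout_params = {
    'wscales': torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device="cuda"),
    'wzeros': torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device="cuda"),
    'group_size': group_size,
    'orig_dtype': torch.bfloat16,
    'is_weight': True,
}
weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

x = torch.randn(4, in_features, dtype=torch.bfloat16, device="cuda")
out = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})
# out: shape (4, out_features), dtype bfloat16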
377
comfy/svdquant_converter.py
Normal file
@ -0,0 +1,377 @@
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
from safetensors import safe_open
from safetensors.torch import save_file


# Note: Fused layer splitting is no longer used


@dataclass
class ConvertedState:
    tensors: Dict[str, torch.Tensor]
    quant_layers: Dict[str, str]


def _is_svd_prefix(keys: set[str], prefix: str) -> bool:
    return (
        f"{prefix}.qweight" in keys
        and f"{prefix}.smooth_factor" in keys
        and f"{prefix}.proj_down" in keys
        and f"{prefix}.proj_up" in keys
    )


def _is_awq_prefix(keys: set[str], prefix: str) -> bool:
    return (
        f"{prefix}.qweight" in keys
        and f"{prefix}.wscales" in keys
        and f"{prefix}.wzeros" in keys
        and f"{prefix}.smooth_factor" not in keys  # Distinguish from SVDQuant
    )


def _detect_svd_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    prefixes = set()
    keys = set(state_dict.keys())
    for key in keys:
        if not key.endswith(".qweight"):
            continue
        prefix = key[: -len(".qweight")]
        if _is_svd_prefix(keys, prefix):
            prefixes.add(prefix)
    return sorted(prefixes)


def _detect_awq_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    prefixes = set()
    keys = set(state_dict.keys())
    for key in keys:
        if not key.endswith(".qweight"):
            continue
        prefix = key[: -len(".qweight")]
        if _is_awq_prefix(keys, prefix):
            prefixes.add(prefix)
    return sorted(prefixes)


def _detect_format(wscales: torch.Tensor) -> str:
    if wscales.dtype == torch.float8_e4m3fn:
        return "svdquant_nvfp4"
    return "svdquant_int4"


class _SVDQuantConverter:
    def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.src = dict(state_dict)
        self.dst: Dict[str, torch.Tensor] = {}
        self.quant_layers: Dict[str, str] = {}

    def convert(self) -> ConvertedState:
        prefixes = _detect_svd_prefixes(self.src)
        for prefix in prefixes:
            self._convert_single(prefix)

        for key, tensor in self.src.items():
            if key not in self.dst:
                self.dst[key] = tensor

        return ConvertedState(self.dst, self.quant_layers)

    def _pop_tensor(self, key: str) -> torch.Tensor:
        try:
            return self.src.pop(key)
        except KeyError as exc:
            raise KeyError(f"Missing key '{key}' in SVDQuant checkpoint") from exc

    def _pop_optional(self, key: str) -> torch.Tensor | None:
        return self.src.pop(key, None)

    def _convert_single(self, prefix: str) -> None:
        # Ensure all tensors are contiguous to avoid CUDA alignment issues
        self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
        wscales = self._pop_tensor(f"{prefix}.wscales").contiguous()
        self.dst[f"{prefix}.wscales"] = wscales
        format_name = _detect_format(wscales)

        self.dst[f"{prefix}.smooth_factor"] = self._pop_tensor(f"{prefix}.smooth_factor").contiguous()
        self.dst[f"{prefix}.smooth_factor_orig"] = self._pop_tensor(
            f"{prefix}.smooth_factor_orig"
        ).contiguous()
        self.dst[f"{prefix}.proj_down"] = self._pop_tensor(f"{prefix}.proj_down").contiguous()
        self.dst[f"{prefix}.proj_up"] = self._pop_tensor(f"{prefix}.proj_up").contiguous()

        bias = self._pop_optional(f"{prefix}.bias")
        if bias is not None:
            self.dst[f"{prefix}.bias"] = bias.contiguous()

        wtscale = self._pop_optional(f"{prefix}.wtscale")
        if wtscale is not None:
            self.dst[f"{prefix}.wtscale"] = wtscale.contiguous() if isinstance(wtscale, torch.Tensor) else wtscale

        wcscales = self._pop_optional(f"{prefix}.wcscales")
        if wcscales is not None:
            self.dst[f"{prefix}.wcscales"] = wcscales.contiguous()

        self.quant_layers[prefix] = format_name


class _AWQConverter:
    def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.src = dict(state_dict)
        self.dst: Dict[str, torch.Tensor] = {}
        self.quant_layers: Dict[str, str] = {}

    def convert(self) -> ConvertedState:
        prefixes = _detect_awq_prefixes(self.src)
        for prefix in prefixes:
            self._convert_single(prefix)

        for key, tensor in self.src.items():
            if key not in self.dst:
                self.dst[key] = tensor

        return ConvertedState(self.dst, self.quant_layers)

    def _pop_tensor(self, key: str) -> torch.Tensor:
        try:
            return self.src.pop(key)
        except KeyError as exc:
            raise KeyError(f"Missing key '{key}' in AWQ checkpoint") from exc

    def _pop_optional(self, key: str) -> torch.Tensor | None:
        return self.src.pop(key, None)

    def _convert_single(self, prefix: str) -> None:
        # Ensure all tensors are contiguous to avoid CUDA alignment issues
        self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
        self.dst[f"{prefix}.wscales"] = self._pop_tensor(f"{prefix}.wscales").contiguous()
        self.dst[f"{prefix}.wzeros"] = self._pop_tensor(f"{prefix}.wzeros").contiguous()

        bias = self._pop_optional(f"{prefix}.bias")
        if bias is not None:
            self.dst[f"{prefix}.bias"] = bias.contiguous()

        self.quant_layers[prefix] = "awq_int4"


def convert_svdquant_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
    return _SVDQuantConverter(state_dict).convert()


def convert_awq_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
    return _AWQConverter(state_dict).convert()


def detect_quantization_formats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
    """
    Detect quantization formats present in a state dict.

    Parameters
    ----------
    state_dict : Dict[str, torch.Tensor]
        State dictionary to analyze

    Returns
    -------
    Dict[str, List[str]]
        Dictionary mapping format names to lists of layer prefixes
        Example: {
            "svdquant_int4": ["layer1.attn.qkv", "layer2.mlp.up"],
            "svdquant_nvfp4": ["layer3.attn.qkv"],
            "awq_int4": ["layer1.mlp.down", "layer4.attn.qkv"]
        }
    """
    result = {}

    # Detect SVDQuant layers
    svd_prefixes = _detect_svd_prefixes(state_dict)
    if svd_prefixes:
        # Determine if int4 or nvfp4 based on wscales dtype
        for prefix in svd_prefixes:
            wscales_key = f"{prefix}.wscales"
            if wscales_key in state_dict:
                format_name = _detect_format(state_dict[wscales_key])
                if format_name not in result:
                    result[format_name] = []
                result[format_name].append(prefix)

    # Detect AWQ layers
    awq_prefixes = _detect_awq_prefixes(state_dict)
    if awq_prefixes:
        result["awq_int4"] = awq_prefixes

    return result


def convert_awq_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        metadata = dict(f.metadata())

    converted = convert_awq_state_dict(tensors)

    # Convert layer format dict to expected metadata format
    # From: {"layer": "awq_int4"}
    # To: {"layer": {"format": "awq_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}

    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers


def convert_svdquant_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        metadata = dict(f.metadata())

    converted = convert_svdquant_state_dict(tensors)

    # Convert layer format dict to expected metadata format
    # From: {"layer": "svdquant_int4"}
    # To: {"layer": {"format": "svdquant_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}

    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )
    metadata["model_class"] = "QwenImageTransformer2DModel"

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers


def convert_quantized_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
    quant_format: str = "auto",
) -> Tuple[int, Dict[str, str]]:
    """
    Auto-detect and convert quantized checkpoint to ComfyUI format.

    Supports mixed-format models where some layers are SVDQuant and others are AWQ.
    Each layer is independently detected and converted to the appropriate format.

    Parameters
    ----------
    input_path : str
        Path to input checkpoint file
    output_path : str
        Path to output checkpoint file
    format_version : str, optional
        Quantization metadata format version (default: "1.0")
    quant_format : str, optional
        Quantization format: "auto", "svdquant", or "awq" (default: "auto")

    Returns
    -------
    Tuple[int, Dict[str, str]]
        Number of quantized layers and mapping of layer names to formats
    """
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        metadata = dict(f.metadata())

    # Auto-detect format if needed
    if quant_format == "auto":
        svd_prefixes = _detect_svd_prefixes(tensors)
        awq_prefixes = _detect_awq_prefixes(tensors)

        if svd_prefixes and awq_prefixes:
            # Mixed format - partition tensors by format and convert separately

            # Build sets of all quantized prefixes
            all_svd_prefixes = set(svd_prefixes)
            all_awq_prefixes = set(awq_prefixes)

            # Helper to check if a key belongs to a specific quantized layer
            def belongs_to_prefix(key, prefix):
                """Check if key belongs to a specific layer prefix."""
                return key == prefix or key.startswith(f"{prefix}.")

            def is_svd_key(key):
                """Check if key belongs to any SVDQuant layer."""
                return any(belongs_to_prefix(key, prefix) for prefix in all_svd_prefixes)

            def is_awq_key(key):
                """Check if key belongs to any AWQ layer."""
                return any(belongs_to_prefix(key, prefix) for prefix in all_awq_prefixes)

            # Partition tensors by format
            svd_tensors = {}
            awq_tensors = {}
            other_tensors = {}

            for key, tensor in tensors.items():
                if is_svd_key(key):
                    svd_tensors[key] = tensor
                elif is_awq_key(key):
                    awq_tensors[key] = tensor
                else:
                    other_tensors[key] = tensor

            # Convert each format separately with only its relevant tensors
            svd_converted = _SVDQuantConverter(svd_tensors).convert()
            awq_converted = _AWQConverter(awq_tensors).convert()

            # Merge results - each converter only has its own layer tensors
            converted_tensors = {}

            # Add SVDQuant converted tensors
            converted_tensors.update(svd_converted.tensors)

            # Add AWQ converted tensors
            converted_tensors.update(awq_converted.tensors)

            # Add non-quantized tensors
            converted_tensors.update(other_tensors)

            # Merge quantization layer metadata
            quant_layers = {}
            quant_layers.update(svd_converted.quant_layers)
            quant_layers.update(awq_converted.quant_layers)

            converted = ConvertedState(converted_tensors, quant_layers)
        elif svd_prefixes:
            converted = convert_svdquant_state_dict(tensors)
        elif awq_prefixes:
            converted = convert_awq_state_dict(tensors)
        else:
            raise ValueError("No quantized layers detected in checkpoint")
    elif quant_format == "svdquant":
        converted = convert_svdquant_state_dict(tensors)
    elif quant_format == "awq":
        converted = convert_awq_state_dict(tensors)
    else:
        raise ValueError(f"Unknown quantization format: {quant_format}")

    # Convert layer format dict to expected metadata format
    # From: {"layer": "awq_int4"}
    # To: {"layer": {"format": "awq_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}

    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )
    metadata["model_class"] = "QwenImageTransformer2DModel"

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers
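Detection in this module is purely name-based: a prefix is classified by which sibling tensors accompany its .qweight. An illustrative sketch (not part of the diff) with dummy tensors, where the layer names and shapes are made up and only the key suffixes and the wscales dtype matter:

import torch
from comfy.svdquant_converter import detect_quantization_formats

state_dict = {
    # SVDQuant layer: qweight + smooth_factor(+_orig) + proj_down/proj_up, wscales dtype picks int4 vs nvfp4
    "blocks.0.attn.qkv.qweight": torch.zeros(256, 64, dtype=torch.int8),
    "blocks.0.attn.qkv.wscales": torch.zeros(2, 256, dtype=torch.bfloat16),
    "blocks.0.attn.qkv.smooth_factor": torch.zeros(128, dtype=torch.bfloat16),
    "blocks.0.attn.qkv.smooth_factor_orig": torch.zeros(128, dtype=torch.bfloat16),
    "blocks.0.attn.qkv.proj_down": torch.zeros(128, 32, dtype=torch.bfloat16),
    "blocks.0.attn.qkv.proj_up": torch.zeros(256, 32, dtype=torch.bfloat16),
    # AWQ layer: qweight + wscales + wzeros and no smooth_factor
    "blocks.0.mlp.down.qweight": torch.zeros(64, 64, dtype=torch.int32),
    "blocks.0.mlp.down.wscales": torch.zeros(2, 256, dtype=torch.bfloat16),
    "blocks.0.mlp.down.wzeros": torch.zeros(2, 256, dtype=torch.bfloat16),
}

print(detect_quantization_formats(state_dict))
# {'svdquant_int4': ['blocks.0.attn.qkv'], 'awq_int4': ['blocks.0.mlp.down']}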
116
convert_svdquant_checkpoint.py
Normal file
@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Convert quantized checkpoints (SVDQuant, AWQ, or mixed) into the ComfyUI quantization format.
"""

import argparse
from pathlib import Path
from safetensors import safe_open

from comfy.svdquant_converter import (
    convert_quantized_file,
    convert_svdquant_file,
    convert_awq_file,
    detect_quantization_formats,
)


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Convert quantized .safetensors files (SVDQuant, AWQ, or mixed) "
        "into the ComfyUI format with per-layer metadata for MixedPrecisionOps."
    )
    parser.add_argument("input", type=Path, help="Path to the source quantized .safetensors file.")
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help="Destination path for the converted checkpoint. "
        "Defaults to <input_name>_comfy.safetensors in the same directory.",
    )
    parser.add_argument(
        "--format-version",
        default="1.0",
        help="Format version to store inside _quantization_metadata (default: 1.0).",
    )
    parser.add_argument(
        "--format",
        choices=["auto", "svdquant", "awq"],
        default="auto",
        help="Quantization format (default: auto-detect).",
    )
    parser.add_argument(
        "--detect-only",
        action="store_true",
        help="Only detect and report quantization formats without converting.",
    )
    return parser


def main() -> None:
    parser = _build_parser()
    args = parser.parse_args()

    input_path = args.input.expanduser().resolve()

    # Detect formats if requested
    if args.detect_only:
        print(f"[Quantization Detector] Analyzing: {input_path}")
        with safe_open(str(input_path), framework="pt", device="cpu") as f:
            tensors = {key: f.get_tensor(key) for key in f.keys()}

        formats = detect_quantization_formats(tensors)

        if not formats:
            print("[Quantization Detector] No quantized layers detected.")
            return

        print("[Quantization Detector] Detected formats:")
        total_layers = 0
        for format_name, layer_prefixes in sorted(formats.items()):
            print(f"\n  {format_name}: {len(layer_prefixes)} layers")
            for prefix in sorted(layer_prefixes)[:5]:  # Show first 5
                print(f"    - {prefix}")
            if len(layer_prefixes) > 5:
                print(f"    ... and {len(layer_prefixes) - 5} more")
            total_layers += len(layer_prefixes)

        print(f"\n[Quantization Detector] Total: {total_layers} quantized layers")
        print("[Quantization Detector] Use without --detect-only to convert.")
        return

    # Convert checkpoint
    if args.output is None:
        output_path = input_path.with_name(f"{input_path.stem}_comfy.safetensors")
    else:
        output_path = args.output.expanduser().resolve()

    layer_count, quant_layers = convert_quantized_file(
        str(input_path),
        str(output_path),
        format_version=args.format_version,
        quant_format=args.format,
    )

    # Group layers by format for display
    format_groups = {}
    for layer_name, fmt in quant_layers.items():
        if fmt not in format_groups:
            format_groups[fmt] = []
        format_groups[fmt].append(layer_name)

    print(f"[Quantization Converter] Converted {layer_count} layers.")
    print(f"[Quantization Converter] Output saved to: {output_path}")
    print("\n[Quantization Converter] Quantized layers by format:")

    for fmt, layers in sorted(format_groups.items()):
        print(f"\n  {fmt}: {len(layers)} layers")
        for layer_name in sorted(layers)[:5]:  # Show first 5
            print(f"    - {layer_name}")
        if len(layers) > 5:
            print(f"    ... and {len(layers) - 5} more")


if __name__ == "__main__":
    main()
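The same conversion the script performs can be driven from Python; a short sketch (file names here are hypothetical):

from comfy.svdquant_converter import convert_quantized_file

# Auto-detects SVDQuant, AWQ, or mixed layers and writes the per-layer
# _quantization_metadata into the output safetensors header.
layer_count, quant_layers = convert_quantized_file(
    "qwen_image_svdq_int4.safetensors",        # hypothetical input checkpoint
    "qwen_image_svdq_int4_comfy.safetensors",  # converted output
    quant_format="auto",
)
print(f"Converted {layer_count} layers: {sorted(set(quant_layers.values()))}")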
@ -1,7 +1,10 @@
-import unittest
-import torch
-import sys
-import os
+import os
+import sys
+import unittest
+from pathlib import Path
+
+import torch
+from safetensors.torch import load_file
 # Add comfy to path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@ -13,7 +16,9 @@ from comfy.cli_args import args
 if not has_gpu():
     args.cpu = True

-from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
+from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout
+from comfy.ops import mixed_precision_ops
+from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict


 class TestQuantizedTensor(unittest.TestCase):
@ -156,6 +161,199 @@ class TestTensorCoreFP8Layout(unittest.TestCase):
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestAWQQuantLayout(unittest.TestCase):
    """Test the AWQQuantLayout implementation"""

    def test_awq_layout_creation(self):
        """Test creating an AWQ quantized tensor"""
        # AWQ uses pre-quantized weights loaded from checkpoints
        # Create dummy AWQ quantized weights
        out_features, in_features = 256, 128
        group_size = 64

        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int32)
        self.assertEqual(qt._layout_type, "AWQQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)

    def test_awq_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for AWQ"""
        # AWQ doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)

        with self.assertRaises(NotImplementedError):
            AWQQuantLayout.quantize(float_tensor, is_weight=True)

    def test_awq_get_plain_tensors(self):
        """Test extracting plain tensors from AWQ quantized tensor"""
        out_features, in_features = 256, 128
        group_size = 64

        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        plain_tensors = AWQQuantLayout.get_plain_tensors(qt)

        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('wzeros', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros))


class TestSVDQuantLayout(unittest.TestCase):
    """Test the SVDQuantLayout implementation"""

    def test_svdquant_layout_creation(self):
        """Test creating an SVDQuant quantized tensor"""
        # SVDQuant uses pre-quantized weights loaded from checkpoints
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64
        precision = "int4"

        # Create dummy SVDQuant quantized weights (int8 range is -128 to 127)
        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': precision,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }

        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int8)
        self.assertEqual(qt._layout_type, "SVDQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)
        self.assertEqual(qt._layout_params['precision'], precision)

    def test_svdquant_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for SVDQuant"""
        # SVDQuant doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)

        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.quantize(float_tensor, is_weight=True)

    def test_svdquant_dequantize_not_supported(self):
        """Test that weight dequantization raises NotImplementedError for SVDQuant"""
        # Full weight dequantization is not supported (complex operation)
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64

        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.dequantize(
                qweight,
                is_weight=True,
                wscales=wscales,
                smooth_factor=smooth_factor,
                proj_down=proj_down,
                proj_up=proj_up,
                group_size=group_size,
                precision="int4",
                orig_dtype=torch.bfloat16
            )

    def test_svdquant_get_plain_tensors(self):
        """Test extracting plain tensors from SVDQuant quantized tensor"""
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64

        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': "int4",
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }

        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
        plain_tensors = SVDQuantLayout.get_plain_tensors(qt)

        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('smooth_factor', plain_tensors)
        self.assertIn('proj_down', plain_tensors)
        self.assertIn('proj_up', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertIn('precision', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor))
        self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down))
        self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

@ -186,5 +384,158 @@ class TestFallbackMechanism(unittest.TestCase):
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


class TestAWQConversion(unittest.TestCase):
    """Test AWQ checkpoint conversion"""

    def test_awq_single_layer_conversion(self):
        """Test converting a single AWQ layer"""
        in_features, out_features = 128, 256
        group_size = 64

        # Create AWQ checkpoint format
        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.bias": torch.randn(out_features, dtype=torch.bfloat16),
        }

        converted = convert_awq_state_dict(state_dict)

        # Check that qweight was renamed to weight
        self.assertIn("layer.weight", converted.tensors)
        self.assertNotIn("layer.qweight", converted.tensors)

        # Check other parameters preserved
        self.assertIn("layer.wscales", converted.tensors)
        self.assertIn("layer.wzeros", converted.tensors)
        self.assertIn("layer.bias", converted.tensors)

        # Check quantization metadata
        self.assertIn("layer", converted.quant_layers)
        self.assertEqual(converted.quant_layers["layer"], "awq_int4")

    def test_awq_tensor_shapes(self):
        """Test that converted AWQ tensors have correct shapes"""
        in_features, out_features = 3072, 18432
        group_size = 64

        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
        }

        converted = convert_awq_state_dict(state_dict)

        # Check qweight shape (packed 4-bit)
        qweight = converted.tensors["layer.weight"]
        self.assertEqual(qweight.shape, (out_features // 4, in_features // 2))
        self.assertEqual(qweight.dtype, torch.int32)

        # Check wscales shape
        wscales = converted.tensors["layer.wscales"]
        self.assertEqual(wscales.shape, (in_features // group_size, out_features))
        self.assertEqual(wscales.dtype, torch.bfloat16)

        # Check wzeros shape
        wzeros = converted.tensors["layer.wzeros"]
        self.assertEqual(wzeros.shape, (in_features // group_size, out_features))
        self.assertEqual(wzeros.dtype, torch.bfloat16)


class TestAWQLinearOperation(unittest.TestCase):
    """Test AWQ linear operations with actual nunchaku kernels"""

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_basic(self):
        """Test basic AWQ linear operation by calling kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
        except ImportError:
            self.skipTest("nunchaku package not available")

        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4

        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        bias = torch.randn(out_features, dtype=torch.bfloat16, device=device)

        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")

        # Create input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)

        # Call AWQ linear handler directly
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {})

        # Check output shape and dtype
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_2d_input(self):
        """Test AWQ linear with 2D input (batch, features) by calling kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
        except ImportError:
            self.skipTest("nunchaku package not available")

        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4

        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)

        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")

        # Create 2D input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)

        # Call AWQ linear handler directly
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})

        # Check output shape
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()