import json
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Note: Fused layer splitting is no longer used.


@dataclass
class ConvertedState:
    tensors: Dict[str, torch.Tensor]
    quant_layers: Dict[str, str]


def _is_svd_prefix(keys: set[str], prefix: str) -> bool:
    return (
        f"{prefix}.qweight" in keys
        and f"{prefix}.smooth_factor" in keys
        and f"{prefix}.proj_down" in keys
        and f"{prefix}.proj_up" in keys
    )


def _is_awq_prefix(keys: set[str], prefix: str) -> bool:
    return (
        f"{prefix}.qweight" in keys
        and f"{prefix}.wscales" in keys
        and f"{prefix}.wzeros" in keys
        and f"{prefix}.smooth_factor" not in keys  # Distinguish from SVDQuant
    )


def _detect_svd_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    prefixes = set()
    keys = set(state_dict.keys())
    for key in keys:
        if not key.endswith(".qweight"):
            continue
        prefix = key[: -len(".qweight")]
        if _is_svd_prefix(keys, prefix):
            prefixes.add(prefix)
    return sorted(prefixes)


def _detect_awq_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    prefixes = set()
    keys = set(state_dict.keys())
    for key in keys:
        if not key.endswith(".qweight"):
            continue
        prefix = key[: -len(".qweight")]
        if _is_awq_prefix(keys, prefix):
            prefixes.add(prefix)
    return sorted(prefixes)


def _detect_format(wscales: torch.Tensor) -> str:
    if wscales.dtype == torch.float8_e4m3fn:
        return "svdquant_nvfp4"
    return "svdquant_int4"
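
# A quick illustration of how the detection helpers above classify a prefix;
# the layer name here is hypothetical, chosen only for this sketch:
#
#   keys = {"blk.0.fc.qweight", "blk.0.fc.wscales", "blk.0.fc.wzeros"}
#   _is_awq_prefix(keys, "blk.0.fc")  # True: qweight/wscales/wzeros, no smooth_factor
#   _is_svd_prefix(keys, "blk.0.fc")  # False: smooth_factor/proj_down/proj_up missing
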

class _SVDQuantConverter:
    def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.src = dict(state_dict)
        self.dst: Dict[str, torch.Tensor] = {}
        self.quant_layers: Dict[str, str] = {}

    def convert(self) -> ConvertedState:
        prefixes = _detect_svd_prefixes(self.src)
        for prefix in prefixes:
            self._convert_single(prefix)
        for key, tensor in self.src.items():
            if key not in self.dst:
                self.dst[key] = tensor
        return ConvertedState(self.dst, self.quant_layers)

    def _pop_tensor(self, key: str) -> torch.Tensor:
        try:
            return self.src.pop(key)
        except KeyError as exc:
            raise KeyError(f"Missing key '{key}' in SVDQuant checkpoint") from exc

    def _pop_optional(self, key: str) -> torch.Tensor | None:
        return self.src.pop(key, None)

    def _convert_single(self, prefix: str) -> None:
        # Ensure all tensors are contiguous to avoid CUDA alignment issues
        self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
        wscales = self._pop_tensor(f"{prefix}.wscales").contiguous()
        self.dst[f"{prefix}.wscales"] = wscales
        format_name = _detect_format(wscales)
        self.dst[f"{prefix}.smooth_factor"] = self._pop_tensor(
            f"{prefix}.smooth_factor"
        ).contiguous()
        self.dst[f"{prefix}.smooth_factor_orig"] = self._pop_tensor(
            f"{prefix}.smooth_factor_orig"
        ).contiguous()
        self.dst[f"{prefix}.proj_down"] = self._pop_tensor(f"{prefix}.proj_down").contiguous()
        self.dst[f"{prefix}.proj_up"] = self._pop_tensor(f"{prefix}.proj_up").contiguous()
        bias = self._pop_optional(f"{prefix}.bias")
        if bias is not None:
            self.dst[f"{prefix}.bias"] = bias.contiguous()
        wtscale = self._pop_optional(f"{prefix}.wtscale")
        if wtscale is not None:
            self.dst[f"{prefix}.wtscale"] = (
                wtscale.contiguous() if isinstance(wtscale, torch.Tensor) else wtscale
            )
        wcscales = self._pop_optional(f"{prefix}.wcscales")
        if wcscales is not None:
            self.dst[f"{prefix}.wcscales"] = wcscales.contiguous()
        self.quant_layers[prefix] = format_name


class _AWQConverter:
    def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None:
        self.src = dict(state_dict)
        self.dst: Dict[str, torch.Tensor] = {}
        self.quant_layers: Dict[str, str] = {}

    def convert(self) -> ConvertedState:
        prefixes = _detect_awq_prefixes(self.src)
        for prefix in prefixes:
            self._convert_single(prefix)
        for key, tensor in self.src.items():
            if key not in self.dst:
                self.dst[key] = tensor
        return ConvertedState(self.dst, self.quant_layers)

    def _pop_tensor(self, key: str) -> torch.Tensor:
        try:
            return self.src.pop(key)
        except KeyError as exc:
            raise KeyError(f"Missing key '{key}' in AWQ checkpoint") from exc

    def _pop_optional(self, key: str) -> torch.Tensor | None:
        return self.src.pop(key, None)

    def _convert_single(self, prefix: str) -> None:
        # Ensure all tensors are contiguous to avoid CUDA alignment issues
        self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous()
        self.dst[f"{prefix}.wscales"] = self._pop_tensor(f"{prefix}.wscales").contiguous()
        self.dst[f"{prefix}.wzeros"] = self._pop_tensor(f"{prefix}.wzeros").contiguous()
        bias = self._pop_optional(f"{prefix}.bias")
        if bias is not None:
            self.dst[f"{prefix}.bias"] = bias.contiguous()
        self.quant_layers[prefix] = "awq_int4"


def convert_svdquant_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
    return _SVDQuantConverter(state_dict).convert()


def convert_awq_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState:
    return _AWQConverter(state_dict).convert()


def detect_quantization_formats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
    """
    Detect quantization formats present in a state dict.

    Parameters
    ----------
    state_dict : Dict[str, torch.Tensor]
        State dictionary to analyze

    Returns
    -------
    Dict[str, List[str]]
        Dictionary mapping format names to lists of layer prefixes

        Example:
        {
            "svdquant_int4": ["layer1.attn.qkv", "layer2.mlp.up"],
            "svdquant_nvfp4": ["layer3.attn.qkv"],
            "awq_int4": ["layer1.mlp.down", "layer4.attn.qkv"]
        }
    """
    result: Dict[str, List[str]] = {}

    # Detect SVDQuant layers
    svd_prefixes = _detect_svd_prefixes(state_dict)
    if svd_prefixes:
        # Determine if int4 or nvfp4 based on wscales dtype
        for prefix in svd_prefixes:
            wscales_key = f"{prefix}.wscales"
            if wscales_key in state_dict:
                format_name = _detect_format(state_dict[wscales_key])
                if format_name not in result:
                    result[format_name] = []
                result[format_name].append(prefix)

    # Detect AWQ layers
    awq_prefixes = _detect_awq_prefixes(state_dict)
    if awq_prefixes:
        result["awq_int4"] = awq_prefixes

    return result
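
# Illustrative sketch only (never called by the converter): running the AWQ
# path on a fake single-layer state dict. The layer name, shapes, and dtypes
# below are made up for the example.
def _example_awq_conversion() -> ConvertedState:
    fake = {
        "layer.fc.qweight": torch.zeros(16, 16, dtype=torch.int32),
        "layer.fc.wscales": torch.ones(16, dtype=torch.float16),
        "layer.fc.wzeros": torch.zeros(16, dtype=torch.float16),
    }
    state = convert_awq_state_dict(fake)
    # qweight is renamed to weight, and the layer is recorded as awq_int4
    assert "layer.fc.weight" in state.tensors
    assert state.quant_layers == {"layer.fc": "awq_int4"}
    return state
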

def convert_awq_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        # f.metadata() is None when the file carries no metadata
        metadata = dict(f.metadata() or {})

    converted = convert_awq_state_dict(tensors)

    # Convert layer format dict to expected metadata format
    # From: {"layer": "awq_int4"}
    # To:   {"layer": {"format": "awq_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers


def convert_svdquant_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
) -> Tuple[int, Dict[str, str]]:
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        # f.metadata() is None when the file carries no metadata
        metadata = dict(f.metadata() or {})

    converted = convert_svdquant_state_dict(tensors)

    # Convert layer format dict to expected metadata format
    # From: {"layer": "svdquant_int4"}
    # To:   {"layer": {"format": "svdquant_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )
    metadata["model_class"] = "QwenImageTransformer2DModel"

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers
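
# For reference, the `_quantization_metadata` string written by the functions
# in this module serializes to JSON of this shape (layer names hypothetical):
#
#   {
#       "format_version": "1.0",
#       "layers": {
#           "blocks.0.attn.qkv": {"format": "svdquant_int4"},
#           "blocks.0.mlp.down": {"format": "awq_int4"}
#       }
#   }
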

def convert_quantized_file(
    input_path: str,
    output_path: str,
    format_version: str = "1.0",
    quant_format: str = "auto",
) -> Tuple[int, Dict[str, str]]:
    """
    Auto-detect and convert a quantized checkpoint to ComfyUI format.

    Supports mixed-format models where some layers are SVDQuant and others
    are AWQ. Each layer is independently detected and converted to the
    appropriate format.

    Parameters
    ----------
    input_path : str
        Path to input checkpoint file
    output_path : str
        Path to output checkpoint file
    format_version : str, optional
        Quantization metadata format version (default: "1.0")
    quant_format : str, optional
        Quantization format: "auto", "svdquant", or "awq" (default: "auto")

    Returns
    -------
    Tuple[int, Dict[str, str]]
        Number of quantized layers and mapping of layer names to formats
    """
    with safe_open(input_path, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        # f.metadata() is None when the file carries no metadata
        metadata = dict(f.metadata() or {})

    # Auto-detect format if needed
    if quant_format == "auto":
        svd_prefixes = _detect_svd_prefixes(tensors)
        awq_prefixes = _detect_awq_prefixes(tensors)

        if svd_prefixes and awq_prefixes:
            # Mixed format - partition tensors by format and convert separately
            # Build sets of all quantized prefixes
            all_svd_prefixes = set(svd_prefixes)
            all_awq_prefixes = set(awq_prefixes)

            # Helpers to check whether a key belongs to a specific quantized layer
            def belongs_to_prefix(key: str, prefix: str) -> bool:
                """Check if key belongs to a specific layer prefix."""
                return key == prefix or key.startswith(f"{prefix}.")

            def is_svd_key(key: str) -> bool:
                """Check if key belongs to any SVDQuant layer."""
                return any(belongs_to_prefix(key, prefix) for prefix in all_svd_prefixes)

            def is_awq_key(key: str) -> bool:
                """Check if key belongs to any AWQ layer."""
                return any(belongs_to_prefix(key, prefix) for prefix in all_awq_prefixes)

            # Partition tensors by format
            svd_tensors = {}
            awq_tensors = {}
            other_tensors = {}
            for key, tensor in tensors.items():
                if is_svd_key(key):
                    svd_tensors[key] = tensor
                elif is_awq_key(key):
                    awq_tensors[key] = tensor
                else:
                    other_tensors[key] = tensor

            # Convert each format separately with only its relevant tensors
            svd_converted = _SVDQuantConverter(svd_tensors).convert()
            awq_converted = _AWQConverter(awq_tensors).convert()

            # Merge results - each converter only has its own layer tensors
            converted_tensors = {}
            # Add SVDQuant converted tensors
            converted_tensors.update(svd_converted.tensors)
            # Add AWQ converted tensors
            converted_tensors.update(awq_converted.tensors)
            # Add non-quantized tensors
            converted_tensors.update(other_tensors)

            # Merge quantization layer metadata
            quant_layers = {}
            quant_layers.update(svd_converted.quant_layers)
            quant_layers.update(awq_converted.quant_layers)

            converted = ConvertedState(converted_tensors, quant_layers)
        elif svd_prefixes:
            converted = convert_svdquant_state_dict(tensors)
        elif awq_prefixes:
            converted = convert_awq_state_dict(tensors)
        else:
            raise ValueError("No quantized layers detected in checkpoint")
    elif quant_format == "svdquant":
        converted = convert_svdquant_state_dict(tensors)
    elif quant_format == "awq":
        converted = convert_awq_state_dict(tensors)
    else:
        raise ValueError(f"Unknown quantization format: {quant_format}")

    # Convert layer format dict to expected metadata format
    # From: {"layer": "awq_int4"}
    # To:   {"layer": {"format": "awq_int4"}}
    layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()}
    metadata["_quantization_metadata"] = json.dumps(
        {"format_version": format_version, "layers": layers_metadata}, sort_keys=True
    )
    metadata["model_class"] = "QwenImageTransformer2DModel"

    save_file(converted.tensors, output_path, metadata=metadata)
    return len(converted.quant_layers), converted.quant_layers
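
# A minimal command-line entry point, sketched here for convenience. The flag
# names are our own choice for this sketch, not part of any established CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a quantized checkpoint to ComfyUI format"
    )
    parser.add_argument("input_path", help="Path to the source .safetensors file")
    parser.add_argument("output_path", help="Path to write the converted file")
    parser.add_argument(
        "--quant-format",
        default="auto",
        choices=["auto", "svdquant", "awq"],
        help="Quantization format of the source checkpoint",
    )
    args = parser.parse_args()

    num_layers, layer_formats = convert_quantized_file(
        args.input_path, args.output_path, quant_format=args.quant_format
    )
    print(f"Converted {num_layers} quantized layer(s)")
    for name, fmt in sorted(layer_formats.items()):
        print(f"  {name}: {fmt}")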