WIP PTQ tool

lspindler 2025-10-28 08:26:02 +01:00
parent 9d9f98cb72
commit c4e965df06
13 changed files with 2273 additions and 0 deletions


@ -0,0 +1,255 @@
import argparse
import logging
import sys
import yaml
import re
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
import json
# Add comfyui to path if needed
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
import comfy.utils
from comfy.ops import QUANT_FORMAT_MIXINS
from comfy.quant_ops import F8_E4M3_MAX, F4_E2M1_MAX
class QuantizationConfig:
def __init__(self, config_path: str):
with open(config_path, 'r') as f:
self.config = yaml.safe_load(f)
# Compile disable list patterns
self.disable_patterns = []
for pattern in self.config.get('disable_list', []):
# Convert glob-style patterns to regex
regex_pattern = pattern.replace('*', '.*')
self.disable_patterns.append(re.compile(regex_pattern))
# Parse per-layer dtype config
self.per_layer_dtype = self.config.get('per_layer_dtype', {})
self.dtype_patterns = []
for pattern, dtype in self.per_layer_dtype.items():
regex_pattern = pattern.replace('*', '.*')
self.dtype_patterns.append((re.compile(regex_pattern), dtype))
logging.info(f"Loaded config with {len(self.disable_patterns)} disable patterns")
logging.info(f"Per-layer dtype rules: {self.per_layer_dtype}")
def should_quantize(self, layer_name: str) -> bool:
for pattern in self.disable_patterns:
if pattern.match(layer_name):
logging.debug(f"Layer {layer_name} disabled by pattern {pattern.pattern}")
return False
return True
def get_dtype(self, layer_name: str) -> str:
for pattern, dtype in self.dtype_patterns:
if pattern.match(layer_name):
return dtype
return None
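# To make the matching semantics concrete, a small sketch (layer names below are
# illustrative, not taken from a real checkpoint):
#
#   re.compile("*img_in*".replace("*", ".*"))   # -> pattern ".*img_in.*"
#   .match("img_in")                            # -> match: layer is disabled
#   .match("double_blocks.0.img_attn.qkv")      # -> None: layer stays eligible
#
# per_layer_dtype rules are checked in insertion order and the first match wins,
# so a bare "*" entry acts as the default format.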
def load_amax_artefact(artefact_path: str) -> Dict:
logging.info(f"Loading amax artefact from {artefact_path}")
with open(artefact_path, 'r') as f:
data = json.load(f)
if 'amax_values' not in data:
raise ValueError("Invalid artefact format: missing 'amax_values' key")
metadata = data.get('metadata', {})
amax_values = data['amax_values']
logging.info(f"Loaded {len(amax_values)} amax values from artefact")
logging.info(f"Artefact metadata: {metadata}")
return data
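# For reference, a minimal sketch of the artefact layout this loader expects.
# The field names mirror save_amax_dict in tools/ptq/utils.py; the layer names
# and values below are purely illustrative:
#
# {
#   "metadata": {"timestamp": "...", "num_layers": 2, "model_type": "flux_dev"},
#   "amax_values": {
#     "double_blocks.0.img_attn.qkv.weight_quantizer": 0.4321,
#     "double_blocks.0.img_attn.qkv.input_quantizer": 12.5
#   }
# }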
def get_scale_fp8(amax: float, dtype: torch.dtype) -> torch.Tensor:
scale = amax / torch.finfo(dtype).max
scale_tensor = torch.tensor(scale, dtype=torch.float32)
return scale_tensor
def get_scale_nvfp4(amax: float, dtype: torch.dtype) -> torch.Tensor:
scale = amax / (F8_E4M3_MAX * F4_E2M1_MAX)
scale_tensor = torch.tensor(scale, dtype=torch.float32)
return scale_tensor
def get_scale(amax: float, dtype: torch.dtype):
if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
return get_scale_fp8(amax, dtype)
elif dtype in [torch.float4_e2m1fn_x2]:
return get_scale_nvfp4(amax, dtype)
else:
raise ValueError(f"Unsupported dtype {dtype} ")
def apply_quantization(
checkpoint: Dict,
amax_values: Dict[str, float],
config: QuantizationConfig
) -> Tuple[Dict, Dict]:
quantized_dict = {}
layer_metadata = {}
for key, amax in amax_values.items():
if key.endswith(".input_quantizer"):
continue
layer_name = ".".join(key.split(".")[:-1])
if not config.should_quantize(layer_name):
logging.debug(f"Layer {layer_name} disabled by config")
continue
dtype_str = config.get_dtype(layer_name)
if dtype_str is None:
continue  # no per_layer_dtype rule matched; leave layer unquantized
dtype = getattr(torch, dtype_str)
device = torch.device("cuda") # Required for NVFP4
weight = checkpoint.pop(f"{layer_name}.weight").to(device)
scale_tensor = get_scale(amax, dtype)
input_amax = amax_values.get(f"{layer_name}.input_quantizer", None)
if input_amax is not None:
input_scale = get_scale(input_amax, dtype)
quantized_dict[f"{layer_name}.input_scale"] = input_scale.clone()
# logging.info(f"Quantizing {layer_name}: amax={amax}, scale={scale_tensor:.6f}")
tensor_layout = QUANT_FORMAT_MIXINS[dtype_str]["layout_type"]
quantized_weight, layout_params = tensor_layout.quantize(
weight,
scale=scale_tensor,
dtype=dtype
)
quantized_dict[f"{layer_name}.weight_scale"] = scale_tensor.clone()
quantized_dict[f"{layer_name}.weight"] = quantized_weight.clone()
if "block_scale" in layout_params:
quantized_dict[f"{layer_name}.weight_block_scale"] = layout_params["block_scale"].clone()
# Build metadata
layer_metadata[layer_name] = {
"format": dtype_str,
"params": {}
}
logging.info(f"Quantized {len(layer_metadata)} layers")
quantized_dict = quantized_dict | checkpoint
metadata_dict = {
"_quantization_metadata": json.dumps({
"format_version": "1.0",
"layers": layer_metadata
})
}
return quantized_dict, metadata_dict
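# Sketch of the safetensors entries produced for one quantized layer (the layer
# name is illustrative; weight_block_scale only appears for layouts that return
# a block scale, e.g. NVFP4):
#
#   double_blocks.0.img_mlp.0.weight             -> quantized weight tensor
#   double_blocks.0.img_mlp.0.weight_scale       -> per-tensor weight scale (float32)
#   double_blocks.0.img_mlp.0.input_scale        -> activation scale, if an input amax was calibrated
#   double_blocks.0.img_mlp.0.weight_block_scale -> block scales, layout dependent
#
# The safetensors header additionally carries "_quantization_metadata" with the
# per-layer format information assembled above.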
def main():
"""Main entry point for checkpoint merger."""
parser = argparse.ArgumentParser(
description="Merge calibration artifacts with checkpoint to create quantized model",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--artefact",
required=True,
help="Path to calibration artefact JSON file (amax values)"
)
parser.add_argument(
"--checkpoint",
required=True,
help="Path to original checkpoint to quantize"
)
parser.add_argument(
"--config",
required=True,
help="Path to YAML quantization config file"
)
parser.add_argument(
"--output",
required=True,
help="Output path for quantized checkpoint"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug logging"
)
args = parser.parse_args()
# Configure logging
if args.debug:
logging.basicConfig(
level=logging.DEBUG,
format='[%(levelname)s] %(name)s: %(message)s'
)
else:
logging.basicConfig(
level=logging.INFO,
format='[%(levelname)s] %(message)s'
)
# Step 1: Load calibration artefact
logging.info("[1/5] Loading calibration artefact...")
try:
artefact_data = load_amax_artefact(args.artefact)
amax_values = artefact_data['amax_values']
except Exception as e:
logging.error(f"Failed to load artefact: {e}")
sys.exit(1)
# Step 2: Load quantization config
logging.info("[2/5] Loading quantization config...")
try:
config = QuantizationConfig(args.config)
except Exception as e:
logging.error(f"Failed to load config: {e}")
sys.exit(1)
# Step 3: Load checkpoint
logging.info("[3/5] Loading checkpoint...")
try:
checkpoint = comfy.utils.load_torch_file(args.checkpoint)
logging.info(f"Loaded checkpoint with {len(checkpoint)} keys")
except Exception as e:
logging.error(f"Failed to load checkpoint: {e}")
sys.exit(1)
# Step 4: Apply quantization
logging.info("[4/5] Applying quantization...")
try:
quantized_dict, metadata_json = apply_quantization(
checkpoint,
amax_values,
config
)
except Exception as e:
logging.error(f"Failed to apply quantization: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Step 5: Export quantized checkpoint
logging.info("[5/5] Exporting quantized checkpoint...")
try:
save_file(quantized_dict, args.output, metadata=metadata_json)
except Exception as e:
logging.error(f"Failed to export checkpoint: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
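# Typical invocation (a sketch: this script's filename is not visible in the
# diff, so "merge_quant.py" is a hypothetical placeholder, as are all paths;
# the flags are the ones defined above):
#
#   python tools/ptq/merge_quant.py \
#       --artefact flux_dev_amax.json \
#       --config tools/ptq/example.yml \
#       --checkpoint models/flux1-dev.safetensors \
#       --output models/flux1-dev-fp8.safetensors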


@ -0,0 +1,23 @@
# FLUX Quantization Config: FP8, Transformer Blocks Only
#
# Quantize only the double and single transformer blocks to FP8 E4M3;
# leave input/output projections in higher precision.
disable_list: [
# Disable input projections
"*img_in*",
"*txt_in*",
"*time_in*",
"*vector_in*",
"*guidance_in*",
# Disable output layers
"*final_layer*",
# Disable positional embeddings
"*pe_embedder*",
]
per_layer_dtype: {
"*": "float8_e4m3fn",
}


@ -0,0 +1,27 @@
# FLUX Quantization Config: NVFP4, Transformer Blocks Only
#
# Quantize only the double and single transformer blocks to NVFP4;
# leave input/output projections, positional embeddings, and modulation
# layers in higher precision.
disable_list: [
# Disable input projections
"*img_in*",
"*txt_in*",
"*time_in*",
"*vector_in*",
"*guidance_in*",
# Disable output layers
"*final_layer*",
# Disable positional embeddings
"*pe_embedder*",
"*modulation*",
"*txt_mod*",
"*img_mod*",
]
per_layer_dtype: {
"*": "float4_e2m1fn_x2",
}
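# Because per_layer_dtype rules are matched in order with the first match
# winning, a mixed-precision variant could keep selected layers in FP8 while
# defaulting to NVFP4, e.g. (a sketch; the attention pattern is illustrative):
#
# per_layer_dtype: {
#   "*img_attn*": "float8_e4m3fn",
#   "*": "float4_e2m1fn_x2",
# }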

File diff suppressed because it is too large


@ -0,0 +1,24 @@
from datasets import load_dataset
from torch.utils.data import Dataset
class OmniDataset(Dataset):
def __init__(self):
self.dataset = load_dataset("stepfun-ai/GEdit-Bench", split="train").filter(
lambda x: x["instruction_language"] == "en")
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
sample = self.dataset[idx]
return {
"prompt": sample["instruction"],
"img_pil": sample["input_image_raw"]
}
if __name__ == "__main__":
dataset = OmniDataset()
dataset[0]

tools/ptq/dataset/t2i.py

@ -0,0 +1,16 @@
import torch
import os
class PromptDataset(torch.utils.data.Dataset):
def __init__(self, calib_data_path="../data/calib_prompts.txt"):
if not os.path.exists(calib_data_path):
raise FileNotFoundError(f"Calibration prompt file not found: {calib_data_path}")
with open(calib_data_path, "r", encoding="utf8") as file:
lst = [line.rstrip("\n") for line in file]
self.prompts = lst
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return self.prompts[idx]
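# Minimal usage sketch (mirrors the DataLoader setup in tools/ptq/quantize.py;
# the path and prompts are illustrative, one prompt per line in the file):
#
#   dataset = PromptDataset("tools/ptq/data/calib_prompts.txt")
#   dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
#   for batch in dataloader:   # with batch_size=1, batch is a one-element list of strings
#       print(batch[0])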

tools/ptq/example.yml

@ -0,0 +1,30 @@
# Quantization Configuration for Checkpoint Merger
#
# This file defines which layers to quantize and what precision to use.
# Patterns use glob-style syntax where * matches any characters.
# Glob-style patterns of layers to DISABLE quantization
# If a layer matches any pattern here, it will NOT be quantized
disable_list: [
# Example: disable input/output projection layers
# "*img_in*",
# "*txt_in*",
# "*final_layer*",
# Example: disable specific block types
# "*norm*",
# "*time_in*",
]
# Per-layer dtype configuration
# Maps layer name patterns to torch dtype names (e.g. float8_e4m3fn, float4_e2m1fn_x2)
# Layers are matched in order - first match wins
per_layer_dtype: {
# Default: quantize all layers to FP8 E4M3
"*": "float8_e4m3fn",
# Example: use different precision for specific layers
# "*attn*": "float8_e4m3fn", # Attention layers
# "*mlp*": "float8_e4m3fn", # MLP layers
# "*qkv*": "float8_e4m3fn", # Q/K/V projections
}


@ -0,0 +1,76 @@
"""
Model recipe registry for PTQ toolkit.
Recipes define model-specific quantization logic and are registered
via the @register_recipe decorator.
"""
from typing import Dict, Type
from .base import ModelRecipe
_RECIPE_REGISTRY: Dict[str, Type[ModelRecipe]] = {}
def register_recipe(recipe_cls: Type[ModelRecipe]):
"""
Decorator to register a model recipe.
Example:
@register_recipe
class FluxDevRecipe(ModelRecipe):
@classmethod
def name(cls):
return "flux_dev"
...
Args:
recipe_cls: Recipe class inheriting from ModelRecipe
Returns:
The recipe class (unchanged)
"""
recipe_name = recipe_cls.name()
if recipe_name in _RECIPE_REGISTRY:
raise ValueError(f"Recipe '{recipe_name}' is already registered")
_RECIPE_REGISTRY[recipe_name] = recipe_cls
return recipe_cls
def get_recipe_class(name: str) -> Type[ModelRecipe]:
"""
Get recipe class by name.
Args:
name: Recipe name (e.g., 'flux_dev')
Returns:
Recipe class
Raises:
ValueError: If recipe name is not found
"""
if name not in _RECIPE_REGISTRY:
available = ", ".join(sorted(_RECIPE_REGISTRY.keys()))
raise ValueError(
f"Unknown model type '{name}'. "
f"Available recipes: {available}"
)
return _RECIPE_REGISTRY[name]
def list_recipes():
"""
List all available recipe names.
Returns:
List of recipe names (sorted)
"""
return sorted(_RECIPE_REGISTRY.keys())
# Import recipe modules to trigger registration
from . import flux # noqa: F401, E402
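# Usage sketch (mirrors the imports in tools/ptq/quantize.py):
#
#   from tools.ptq.models import get_recipe_class, list_recipes
#   list_recipes()                    # -> ['flux_dev', 'flux_schnell'] with this commit's recipes
#   RecipeCls = get_recipe_class("flux_dev")
#   RecipeCls.name()                  # -> 'flux_dev'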

tools/ptq/models/base.py

@ -0,0 +1,124 @@
"""
Abstract base class for model quantization recipes.
Each model type (FLUX, SDXL, Qwen, etc.) implements this interface
to define model-specific quantization logic.
"""
from abc import ABC, abstractmethod
import argparse
from typing import Tuple, Any, Callable
class ModelRecipe(ABC):
"""
Abstract base class for model quantization recipes.
Each model type implements this interface to define:
- How to load the model
- How to create calibration pipeline
- How to run calibration (forward_loop)
- Which layers to quantize (filter function)
- Model-specific hyperparameters
"""
@classmethod
@abstractmethod
def name(cls) -> str:
"""
Unique identifier for this recipe (e.g., 'flux_dev', 'sdxl').
Used in CLI --model_type argument.
"""
pass
@classmethod
@abstractmethod
def add_model_args(cls, parser: argparse.ArgumentParser):
"""
Add model-specific CLI arguments.
Example:
parser.add_argument("--ckpt_path", required=True)
parser.add_argument("--clip_path", help="Optional CLIP path")
Args:
parser: ArgumentParser to add arguments to
"""
pass
@abstractmethod
def __init__(self, args):
"""
Initialize recipe with parsed CLI arguments.
Store model-specific configuration.
Args:
args: Parsed argparse.Namespace
"""
pass
@abstractmethod
def load_model(self) -> Tuple[Any, ...]:
"""
Load model from checkpoint(s).
Returns:
Tuple of (model_patcher, *other_components)
e.g., (model_patcher, clip, vae) for FLUX
First element MUST be model_patcher (ComfyUI ModelPatcher)
"""
pass
@abstractmethod
def create_calibration_pipeline(self, model_components) -> Any:
"""
Create calibration pipeline for running inference.
The pipeline should have a __call__ method that runs inference
for one calibration iteration.
Args:
model_components: Output from load_model()
Returns:
Pipeline object with __call__(steps, prompt, ...) method
"""
pass
@abstractmethod
def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
"""
Return forward_loop function for ModelOptimizer calibration.
The forward_loop is called by mtq.quantize() to collect activation
statistics. It should iterate through the dataloader and run
inference using the calibration pipeline.
Args:
calib_pipeline: Output from create_calibration_pipeline()
dataloader: DataLoader with calibration data
Returns:
Callable that takes no arguments and runs calibration loop
Example:
def forward_loop():
for prompt in dataloader:
calib_pipeline(steps=4, prompt=prompt)
return forward_loop
"""
pass
@abstractmethod
def get_default_calib_steps(self) -> int:
"""
Default number of calibration steps for this model.
Returns:
Number of calibration iterations (e.g., 128 for FLUX Dev)
"""
pass
def get_filter_func(self):
"""
Optional filter restricting which layers are quantized.
tools/ptq/quantize.py passes this to PTQPipeline; the default of None
means no extra filtering beyond the quantization config.
"""
return None
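# As an illustration of the contract above, a new recipe module (hypothetical,
# e.g. tools/ptq/models/toy.py, not part of this commit) would look roughly like
# the sketch below; tools/ptq/models/__init__.py would then import it to trigger
# registration.

from . import register_recipe
from .base import ModelRecipe


@register_recipe
class ToyRecipe(ModelRecipe):
    """Hypothetical example recipe illustrating the abstract interface."""

    @classmethod
    def name(cls) -> str:
        return "toy"

    @classmethod
    def add_model_args(cls, parser):
        parser.add_argument("--ckpt_path", required=True, help="Path to the toy checkpoint")

    def __init__(self, args):
        self.args = args

    def load_model(self):
        # A real recipe returns a ComfyUI ModelPatcher (plus any extra components).
        raise NotImplementedError("toy recipe: model loading is model specific")

    def create_calibration_pipeline(self, model_components):
        raise NotImplementedError("toy recipe: pipeline construction is model specific")

    def get_forward_loop(self, calib_pipeline, dataloader):
        def forward_loop():
            for prompt in dataloader:
                calib_pipeline(4, prompt[0])
        return forward_loop

    def get_default_calib_steps(self) -> int:
        return 32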

tools/ptq/models/flux.py

@ -0,0 +1,277 @@
"""
FLUX model quantization recipes.
Defines FluxDevRecipe and FluxSchnellRecipe for quantizing FLUX models.
"""
import logging
import comfy.sd
import folder_paths
from dataclasses import dataclass
from typing import Tuple, Callable
from comfy_extras.nodes_custom_sampler import (
BasicGuider,
SamplerCustomAdvanced,
BasicScheduler,
RandomNoise,
KSamplerSelect,
)
from comfy_extras.nodes_flux import FluxGuidance
from comfy_extras.nodes_model_advanced import ModelSamplingFlux
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from nodes import CLIPTextEncode
from . import register_recipe
from .base import ModelRecipe
@dataclass
class SamplerCFG:
cfg: float
sampler_name: str
scheduler: str
denoise: float
max_shift: float
base_shift: float
class FluxT2IPipe:
def __init__(
self,
model,
clip,
batch_size,
width=1024,
height=1024,
seed=0,
sampler_cfg: SamplerCFG = None,
device="cuda",
) -> None:
self.clip = clip
self.clip_node = CLIPTextEncode()
self.batch_size = batch_size
self.width = width
self.height = height
self.max_shift = sampler_cfg.max_shift
self.base_shift = sampler_cfg.base_shift
self.device = device
self.seed = seed
self.cfg = sampler_cfg.cfg
self.scheduler_name = sampler_cfg.scheduler
self.denoise = sampler_cfg.denoise
(self.model,) = ModelSamplingFlux().patch(
model, self.max_shift, self.base_shift, self.width, self.height
)
(self.ksampler,) = KSamplerSelect().get_sampler(sampler_cfg.sampler_name)
self.latent_node = EmptySD3LatentImage()
self.guidance = FluxGuidance()
self.sampler = SamplerCustomAdvanced()
self.scheduler_node = BasicScheduler()
self.guider = BasicGuider()
self.noise_generator = RandomNoise()
def __call__(self, num_inference_steps, positive_prompt, *args, **kwargs):
(positive,) = self.clip_node.encode(self.clip, positive_prompt)
(latent_image,) = self.latent_node.generate(
self.width, self.height, self.batch_size
)
(noise,) = self.noise_generator.get_noise(self.seed)
(conditioning,) = self.guidance.append(positive, self.cfg)
(sigmas,) = self.scheduler_node.get_sigmas(
self.model, self.scheduler_name, num_inference_steps, self.denoise
)
(guider,) = self.guider.get_guider(self.model, conditioning)
out, denoised_out = self.sampler.sample(
noise, guider, self.ksampler, sigmas, latent_image
)
return out["samples"]
class FluxRecipeBase(ModelRecipe):
"""Base class for FLUX model quantization recipes."""
@classmethod
def add_model_args(cls, parser):
"""Add FLUX-specific CLI arguments."""
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--ckpt_path",
help="Path to full FLUX checkpoint (includes diffusion model + CLIP + T5)"
)
group.add_argument(
"--unet_path",
help="Path to FLUX diffusion model only (requires --clip_path and --t5_path)"
)
parser.add_argument(
"--clip_path",
help="Path to CLIP text encoder (required with --unet_path)"
)
parser.add_argument(
"--t5_path",
help="Path to T5 text encoder (required with --unet_path)"
)
def __init__(self, args):
"""Initialize FLUX recipe with CLI args."""
self.args = args
# Validate args
if hasattr(args, 'unet_path') and args.unet_path:
if not args.clip_path or not args.t5_path:
raise ValueError("--unet_path requires both --clip_path and --t5_path")
def load_model(self) -> Tuple:
"""Load FLUX model, CLIP, and VAE."""
if hasattr(self.args, 'ckpt_path') and self.args.ckpt_path:
# Load from full checkpoint
logging.info(f"Loading full checkpoint from {self.args.ckpt_path}")
model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
self.args.ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=None
)
else:
# Load from separate files
logging.info(f"Loading diffusion model from {self.args.unet_path}")
model_options = {}
clip_type = comfy.sd.CLIPType.FLUX
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path)
model_patcher = comfy.sd.load_diffusion_model(
self.args.unet_path,
model_options=model_options
)
clip = comfy.sd.load_clip(
ckpt_paths=[clip_path1, clip_path2],
embedding_directory=folder_paths.get_folder_paths("embeddings"),
clip_type=clip_type,
model_options=model_options
)
vae = None # Not needed for calibration
return model_patcher, clip, vae
def create_calibration_pipeline(self, model_components):
"""Create FluxT2IPipe for calibration."""
model_patcher, clip, vae = model_components
return FluxT2IPipe(
model=model_patcher,
clip=clip,
batch_size=1,
width=self.get_width(),
height=self.get_height(),
seed=42,
sampler_cfg=self.get_sampler_cfg(),
device="cuda"
)
def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
"""
Return forward_loop for ModelOptimizer calibration.
Iterates through the dataloader and runs full sampling
for each prompt to collect activation statistics.
"""
num_steps = self.get_inference_steps()
def forward_loop():
for i, prompt in enumerate(dataloader):
# Dataloader returns batches, extract first element
prompt_text = prompt[0] if isinstance(prompt, (list, tuple)) else prompt
logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'")
try:
# Run full sampling pipeline
calib_pipeline(num_steps, prompt_text)
except Exception as e:
logging.warning(f"Calibration step {i+1} failed: {e}")
# Continue with next prompt
return forward_loop
# Abstract methods for variants to implement
def get_width(self) -> int:
"""Image width for calibration."""
raise NotImplementedError
def get_height(self) -> int:
"""Image height for calibration."""
raise NotImplementedError
def get_sampler_cfg(self) -> SamplerCFG:
"""Sampler configuration."""
raise NotImplementedError
def get_inference_steps(self) -> int:
"""Number of sampling steps per calibration iteration."""
raise NotImplementedError
@register_recipe
class FluxDevRecipe(FluxRecipeBase):
"""FLUX Dev quantization recipe."""
@classmethod
def name(cls) -> str:
return "flux_dev"
def get_default_calib_steps(self) -> int:
return 128
def get_width(self) -> int:
return 1024
def get_height(self) -> int:
return 1024
def get_inference_steps(self) -> int:
return 30
def get_sampler_cfg(self) -> SamplerCFG:
return SamplerCFG(
cfg=3.5,
sampler_name="euler",
scheduler="simple",
denoise=1.0,
max_shift=1.15,
base_shift=0.5
)
@register_recipe
class FluxSchnellRecipe(FluxRecipeBase):
"""FLUX Schnell quantization recipe."""
@classmethod
def name(cls) -> str:
return "flux_schnell"
def get_default_calib_steps(self) -> int:
return 64
def get_width(self) -> int:
return 1024
def get_height(self) -> int:
return 1024
def get_inference_steps(self) -> int:
return 4
def get_sampler_cfg(self) -> SamplerCFG:
return SamplerCFG(
cfg=1.0,
sampler_name="euler",
scheduler="simple",
denoise=1.0,
max_shift=1.15,
base_shift=0.5
)

tools/ptq/quantize.py

@ -0,0 +1,187 @@
import argparse
import logging
import sys
import torch.utils.data
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from tools.ptq.models import get_recipe_class, list_recipes
from tools.ptq.quantizer import PTQPipeline
from tools.ptq.utils import register_comfy_ops, FP8_CFG
from tools.ptq.dataset.t2i import PromptDataset
def main():
"""Main entry point for PTQ CLI."""
# Step 1: Parse model_type first to determine which recipe to use
parser = argparse.ArgumentParser(
description="Quantize ComfyUI models using NVIDIA ModelOptimizer",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--model_type",
required=True,
choices=list_recipes(),
help="Model recipe to use"
)
# Parse just model_type first to get recipe class
args, remaining = parser.parse_known_args()
# Step 2: Get recipe class and add its model-specific arguments
recipe_cls = get_recipe_class(args.model_type)
recipe_cls.add_model_args(parser)
# Step 3: Add common arguments
parser.add_argument(
"--output",
required=True,
help="Output path for amax artefact"
)
parser.add_argument(
"--calib_steps",
type=int,
help="Override default calibration steps"
)
parser.add_argument(
"--calib_data",
default="tools/ptq/data/calib_prompts.txt",
help="Path to calibration prompts"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode (sets logging to DEBUG and calib_steps to 1)"
)
# Step 4: Parse all arguments
args = parser.parse_args()
# Configure logging
if args.debug:
logging.basicConfig(
level=logging.DEBUG,
format='[%(levelname)s] %(name)s: %(message)s'
)
logging.info("Debug mode enabled")
else:
logging.basicConfig(
level=logging.INFO,
format='[%(levelname)s] %(message)s'
)
# Step 5: Create recipe instance
try:
recipe = recipe_cls(args)
except Exception as e:
logging.error(f"Failed to initialize recipe: {e}")
sys.exit(1)
# Debug mode overrides calibration steps
if args.debug:
calib_steps = 1
logging.debug("Debug mode: forcing calib_steps=1")
elif args.calib_steps:
calib_steps = args.calib_steps
else:
calib_steps = recipe.get_default_calib_steps()
# Step 6: Register ComfyUI ops with ModelOptimizer
logging.info("Registering ComfyUI ops with ModelOptimizer...")
register_comfy_ops()
# Step 7: Load model
logging.info("[1/6] Loading model...")
try:
model_components = recipe.load_model()
model_patcher = model_components[0]
except Exception as e:
logging.error(f"Failed to load model: {e}")
sys.exit(1)
# Step 8: Create PTQ pipeline
logging.info("[2/6] Preparing quantization...")
try:
pipeline = PTQPipeline(
model_patcher,
quant_config=FP8_CFG,
filter_func=recipe.get_filter_func()
)
except Exception as e:
logging.error(f"Failed to prepare quantization: {e}")
sys.exit(1)
# Step 9: Create calibration pipeline
logging.info("[3/6] Creating calibration pipeline...")
try:
calib_pipeline = recipe.create_calibration_pipeline(model_components)
except Exception as e:
logging.error(f"Failed to create calibration pipeline: {e}")
sys.exit(1)
# Step 10: Load calibration data
logging.info(f"[4/6] Loading calibration data from {args.calib_data}")
try:
dataset = PromptDataset(args.calib_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
logging.info(f"Loaded {len(dataset)} prompts")
except Exception as e:
logging.error(f"Failed to load calibration data: {e}")
sys.exit(1)
# Step 11: Run calibration
logging.info(f"[5/6] Running calibration ({calib_steps} steps)...")
try:
pipeline.calibrate_with_pipeline(
calib_pipeline,
dataloader,
num_steps=calib_steps,
get_forward_loop=recipe.get_forward_loop
)
except Exception as e:
logging.error(f"Calibration failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Only save amax values
logging.info("[6/6] Extracting and saving amax values...")
try:
# Build metadata
metadata = {
"model_type": recipe.name(),
"calibration_steps": calib_steps,
"calibration_data": args.calib_data,
"quantization_format": "FP8_E4M3",
"debug_mode": args.debug
}
# Add checkpoint path if available
if hasattr(args, 'ckpt_path') and args.ckpt_path:
metadata["checkpoint_path"] = args.ckpt_path
pipeline.save_amax_values(args.output, metadata=metadata)
# Success!
except Exception as e:
logging.error(f"Failed to save amax values: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
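# An end-to-end calibration run might look like this (a sketch; the paths are
# placeholders, the flags are the ones defined above):
#
#   python tools/ptq/quantize.py \
#       --model_type flux_dev \
#       --unet_path models/unet/flux1-dev.safetensors \
#       --clip_path clip_l.safetensors \
#       --t5_path t5xxl_fp16.safetensors \
#       --calib_steps 32 \
#       --output flux_dev_amax.json
#
# The resulting amax artefact is then merged into the checkpoint with the merge
# script at the top of this commit.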

tools/ptq/quantizer.py

@ -0,0 +1,59 @@
import torch
import logging
from typing import Dict, Callable
import itertools
import modelopt.torch.quantization as mtq
from tools.ptq.utils import log_quant_summary, save_amax_dict, extract_amax_values
class PTQPipeline:
def __init__(self, model_patcher, quant_config: dict, filter_func=None):
self.model_patcher = model_patcher
self.diffusion_model = model_patcher.model.diffusion_model
self.quant_config = quant_config
self.filter_func = filter_func
logging.debug(f"PTQPipeline initialized with config: {quant_config}")
@torch.no_grad()
def calibrate_with_pipeline(
self,
calib_pipeline,
dataloader,
num_steps: int,
get_forward_loop: Callable
):
"""
Run calibration using the model-specific forward loop.
Args:
calib_pipeline: Calibration pipeline (e.g., FluxT2IPipe)
dataloader: DataLoader with calibration data
num_steps: Number of calibration steps
get_forward_loop: Function that returns forward_loop callable
"""
logging.info(f"Running calibration with {num_steps} steps...")
limited_dataloader = itertools.islice(dataloader, num_steps)
forward_loop = get_forward_loop(calib_pipeline, limited_dataloader)
try:
mtq.quantize(self.diffusion_model, self.quant_config, forward_loop=forward_loop)
except Exception as e:
logging.error(f"Calibration failed: {e}")
raise
logging.info("Calibration complete")
log_quant_summary(self.diffusion_model)
def get_amax_dict(self) -> Dict:
return extract_amax_values(self.diffusion_model)
def save_amax_values(self, output_path: str, metadata: dict = None):
amax_dict = self.get_amax_dict()
save_amax_dict(amax_dict, output_path, metadata=metadata)
logging.info(f"Saved amax values to {output_path}")

tools/ptq/utils.py

@ -0,0 +1,96 @@
import torch
import logging
from typing import Dict, Optional
import comfy.ops
from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer
# FP8 E4M3 configuration
FP8_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
"*input_quantizer": {"num_bits": (4, 3), "axis": None},
"default": {"enable": False},
},
"algorithm": "max",
}
def register_comfy_ops():
"""Register ComfyUI operations with ModelOptimizer."""
op = comfy.ops.disable_weight_init.Linear
op_name = op.__name__
if op in QuantModuleRegistry:
logging.debug("ComfyUI Linear already registered with ModelOptimizer")
return
# Register ComfyUI Linear using the same handler as torch.nn.Linear
QuantModuleRegistry.register(
{op: f"comfy.{op_name}"}
)(QuantModuleRegistry._registry[getattr(torch.nn, op_name)])
logging.info("Registered ComfyUI Linear with ModelOptimizer")
def log_quant_summary(model: torch.nn.Module, log_level=logging.INFO):
count = 0
for name, mod in model.named_modules():
if isinstance(mod, TensorQuantizer):
logging.log(log_level, f"{name:80} {mod}")
count += 1
logging.log(log_level, f"{count} TensorQuantizers found in model")
def extract_amax_values(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
amax_dict = {}
for name, module in model.named_modules():
if not isinstance(module, TensorQuantizer):
continue
if not module.is_enabled:
continue
if hasattr(module, '_amax') and module._amax is not None:
amax = module._amax
if not isinstance(amax, torch.Tensor):
amax = torch.tensor(amax, dtype=torch.float32)
amax_dict[name] = amax.clone().cpu()
logging.debug(f"Extracted amax from {name}: {amax.item():.6f}")
logging.info(f"Extracted amax values from {len(amax_dict)} quantizers")
return amax_dict
def save_amax_dict(amax_dict: Dict[str, torch.Tensor], output_path: str, metadata: Optional[Dict] = None):
import json
from datetime import datetime
logging.info(f"Saving {len(amax_dict)} amax values to {output_path}")
# Convert tensors to Python floats for JSON serialization
amax_values = {}
for key, value in amax_dict.items():
if isinstance(value, torch.Tensor):
# Convert to float (scalar) or list
if value.numel() == 1:
amax_values[key] = float(value.item())
else:
amax_values[key] = value.cpu().numpy().tolist()
else:
amax_values[key] = float(value)
# Build output with metadata
output_dict = {
"metadata": {
"timestamp": datetime.now().isoformat(),
"num_layers": len(amax_values),
**(metadata or {})
},
"amax_values": amax_values
}
# Save as formatted JSON for easy inspection
with open(output_path, 'w') as f:
json.dump(output_dict, f, indent=2, sort_keys=True)
logging.info(f"✓ Amax values saved to {output_path}")