Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-17 00:43:48 +08:00

WIP PTQ tool

This commit is contained in:
parent 9d9f98cb72
commit c4e965df06
255  tools/ptq/checkpoint_merger.py  Normal file
@@ -0,0 +1,255 @@
import argparse
import logging
import sys
import yaml
import re
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
import json

# Add comfyui to path if needed
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

import comfy.utils
from comfy.ops import QUANT_FORMAT_MIXINS
from comfy.quant_ops import F8_E4M3_MAX, F4_E2M1_MAX


class QuantizationConfig:
    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Compile disable list patterns
        self.disable_patterns = []
        for pattern in self.config.get('disable_list', []):
            # Convert glob-style patterns to regex
            regex_pattern = pattern.replace('*', '.*')
            self.disable_patterns.append(re.compile(regex_pattern))

        # Parse per-layer dtype config
        self.per_layer_dtype = self.config.get('per_layer_dtype', {})
        self.dtype_patterns = []
        for pattern, dtype in self.per_layer_dtype.items():
            regex_pattern = pattern.replace('*', '.*')
            self.dtype_patterns.append((re.compile(regex_pattern), dtype))

        logging.info(f"Loaded config with {len(self.disable_patterns)} disable patterns")
        logging.info(f"Per-layer dtype rules: {self.per_layer_dtype}")

    def should_quantize(self, layer_name: str) -> bool:
        for pattern in self.disable_patterns:
            if pattern.match(layer_name):
                logging.debug(f"Layer {layer_name} disabled by pattern {pattern.pattern}")
                return False
        return True

    def get_dtype(self, layer_name: str) -> str:
        for pattern, dtype in self.dtype_patterns:
            if pattern.match(layer_name):
                return dtype
        return None


def load_amax_artefact(artefact_path: str) -> Dict:
    logging.info(f"Loading amax artefact from {artefact_path}")

    with open(artefact_path, 'r') as f:
        data = json.load(f)

    if 'amax_values' not in data:
        raise ValueError("Invalid artefact format: missing 'amax_values' key")

    metadata = data.get('metadata', {})
    amax_values = data['amax_values']

    logging.info(f"Loaded {len(amax_values)} amax values from artefact")
    logging.info(f"Artefact metadata: {metadata}")

    return data


def get_scale_fp8(amax: float, dtype: torch.dtype) -> torch.Tensor:
    scale = amax / torch.finfo(dtype).max
    scale_tensor = torch.tensor(scale, dtype=torch.float32)
    return scale_tensor


def get_scale_nvfp4(amax: float, dtype: torch.dtype) -> torch.Tensor:
    scale = amax / (F8_E4M3_MAX * F4_E2M1_MAX)
    scale_tensor = torch.tensor(scale, dtype=torch.float32)
    return scale_tensor


def get_scale(amax: float, dtype: torch.dtype):
    if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        return get_scale_fp8(amax, dtype)
    elif dtype in [torch.float4_e2m1fn_x2]:
        return get_scale_nvfp4(amax, dtype)
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


def apply_quantization(
    checkpoint: Dict,
    amax_values: Dict[str, float],
    config: QuantizationConfig
) -> Tuple[Dict, Dict]:
    quantized_dict = {}
    layer_metadata = {}

    for key, amax in amax_values.items():
        if key.endswith(".input_quantizer"):
            continue

        layer_name = ".".join(key.split(".")[:-1])

        if not config.should_quantize(layer_name):
            logging.debug(f"Layer {layer_name} disabled by config")
            continue

        dtype_str = config.get_dtype(layer_name)
        if dtype_str is None:
            # No per_layer_dtype rule matched; skip instead of calling getattr(torch, None)
            continue
        dtype = getattr(torch, dtype_str)
        device = torch.device("cuda")  # Required for NVFP4

        weight = checkpoint.pop(f"{layer_name}.weight").to(device)
        scale_tensor = get_scale(amax, dtype)

        input_amax = amax_values.get(f"{layer_name}.input_quantizer", None)
        if input_amax is not None:
            input_scale = get_scale(input_amax, dtype)
            quantized_dict[f"{layer_name}.input_scale"] = input_scale.clone()

        # logging.info(f"Quantizing {layer_name}: amax={amax}, scale={scale_tensor:.6f}")
        tensor_layout = QUANT_FORMAT_MIXINS[dtype_str]["layout_type"]
        quantized_weight, layout_params = tensor_layout.quantize(
            weight,
            scale=scale_tensor,
            dtype=dtype
        )
        quantized_dict[f"{layer_name}.weight_scale"] = scale_tensor.clone()
        quantized_dict[f"{layer_name}.weight"] = quantized_weight.clone()

        if "block_scale" in layout_params:
            quantized_dict[f"{layer_name}.weight_block_scale"] = layout_params["block_scale"].clone()

        # Build metadata
        layer_metadata[layer_name] = {
            "format": dtype_str,
            "params": {}
        }

    logging.info(f"Quantized {len(layer_metadata)} layers")

    # Carry over all remaining (non-quantized) tensors unchanged
    quantized_dict = quantized_dict | checkpoint

    metadata_dict = {
        "_quantization_metadata": json.dumps({
            "format_version": "1.0",
            "layers": layer_metadata
        })
    }
    return quantized_dict, metadata_dict


def main():
    """Main entry point for checkpoint merger."""

    parser = argparse.ArgumentParser(
        description="Merge calibration artifacts with checkpoint to create quantized model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--artefact",
        required=True,
        help="Path to calibration artefact JSON file (amax values)"
    )
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to original checkpoint to quantize"
    )
    parser.add_argument(
        "--config",
        required=True,
        help="Path to YAML quantization config file"
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output path for quantized checkpoint"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug logging"
    )

    args = parser.parse_args()

    # Configure logging
    if args.debug:
        logging.basicConfig(
            level=logging.DEBUG,
            format='[%(levelname)s] %(name)s: %(message)s'
        )
    else:
        logging.basicConfig(
            level=logging.INFO,
            format='[%(levelname)s] %(message)s'
        )

    # Step 1: Load calibration artefact
    logging.info("[1/5] Loading calibration artefact...")
    try:
        artefact_data = load_amax_artefact(args.artefact)
        amax_values = artefact_data['amax_values']
    except Exception as e:
        logging.error(f"Failed to load artefact: {e}")
        sys.exit(1)

    # Step 2: Load quantization config
    logging.info("[2/5] Loading quantization config...")
    try:
        config = QuantizationConfig(args.config)
    except Exception as e:
        logging.error(f"Failed to load config: {e}")
        sys.exit(1)

    # Step 3: Load checkpoint
    logging.info("[3/5] Loading checkpoint...")
    try:
        checkpoint = comfy.utils.load_torch_file(args.checkpoint)
        logging.info(f"Loaded checkpoint with {len(checkpoint)} keys")
    except Exception as e:
        logging.error(f"Failed to load checkpoint: {e}")
        sys.exit(1)

    # Step 4: Apply quantization
    logging.info("[4/5] Applying quantization...")
    try:
        quantized_dict, metadata_dict = apply_quantization(
            checkpoint,
            amax_values,
            config
        )
    except Exception as e:
        logging.error(f"Failed to apply quantization: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # Step 5: Export quantized checkpoint
    logging.info("[5/5] Exporting quantized checkpoint...")
    try:
        save_file(quantized_dict, args.output, metadata=metadata_dict)
    except Exception as e:
        logging.error(f"Failed to export checkpoint: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
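A worked example of the per-tensor FP8 scale computed by get_scale_fp8 above (the amax value is hypothetical):

    import torch
    amax = 12.5                                      # max |value| seen during calibration
    fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for E4M3
    scale = amax / fp8_max                           # ~0.0279, saved as "<layer>.weight_scale"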
23  tools/ptq/configs/flux_fp8.yml  Normal file
@@ -0,0 +1,23 @@
# FLUX Quantization Config: Transformer Blocks Only
#
# Quantize only double and single transformer blocks,
# leave input/output projections in higher precision.

disable_list: [
    # Disable input projections
    "*img_in*",
    "*txt_in*",
    "*time_in*",
    "*vector_in*",
    "*guidance_in*",

    # Disable output layers
    "*final_layer*",

    # Disable positional embeddings
    "*pe_embedder*",
]

per_layer_dtype: {
    "*": "float8_e4m3fn",
}
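A sketch of how QuantizationConfig in checkpoint_merger.py applies these globs (the layer names are hypothetical):

    import re
    disable = [re.compile(p.replace("*", ".*")) for p in ["*img_in*", "*final_layer*"]]
    any(p.match("img_in") for p in disable)                        # True  -> stays in high precision
    any(p.match("double_blocks.0.img_attn.qkv") for p in disable)  # False -> quantized per per_layer_dtype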
27  tools/ptq/configs/flux_nvfp4.yml  Normal file
@@ -0,0 +1,27 @@
# FLUX Quantization Config: Transformer Blocks Only
#
# Quantize only double and single transformer blocks,
# leave input/output projections and modulation layers
# in higher precision.

disable_list: [
    # Disable input projections
    "*img_in*",
    "*txt_in*",
    "*time_in*",
    "*vector_in*",
    "*guidance_in*",

    # Disable output layers
    "*final_layer*",

    # Disable positional embeddings
    "*pe_embedder*",

    # Disable modulation layers
    "*modulation*",
    "*txt_mod*",
    "*img_mod*",
]

per_layer_dtype: {
    "*": "float4_e2m1fn_x2",
}
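For comparison with the FP8 config, the NVFP4 per-tensor scale from get_scale_nvfp4 in checkpoint_merger.py; this assumes comfy.quant_ops defines F8_E4M3_MAX = 448.0 and F4_E2M1_MAX = 6.0, the usual E4M3 and E2M1 maxima:

    amax = 10.0                   # hypothetical calibrated maximum
    scale = amax / (448.0 * 6.0)  # ~0.00372; block scales arrive separately via layout_params["block_scale"]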
1079  tools/ptq/data/calib_prompts.txt  Normal file
File diff suppressed because it is too large
24  tools/ptq/dataset/instruct.py  Normal file
@@ -0,0 +1,24 @@
from datasets import load_dataset
from torch.utils.data import Dataset


class OmniDataset(Dataset):
    def __init__(self):
        self.dataset = load_dataset("stepfun-ai/GEdit-Bench", split="train").filter(
            lambda x: x["instruction_language"] == "en")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        return {
            "prompt": sample["instruction"],
            "img_pil": sample["input_image_raw"]
        }


if __name__ == "__main__":
    dataset = OmniDataset()
    dataset[0]
16  tools/ptq/dataset/t2i.py  Normal file
@@ -0,0 +1,16 @@
import torch
import os


class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, calib_data_path="../data/calib_prompts.txt"):
        if not os.path.exists(calib_data_path):
            raise FileNotFoundError(calib_data_path)
        with open(calib_data_path, "r", encoding="utf8") as file:
            lst = [line.rstrip("\n") for line in file]
        self.prompts = lst

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]
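Usage as wired up in quantize.py below (the path assumes running from the repo root):

    import torch.utils.data
    dataset = PromptDataset("tools/ptq/data/calib_prompts.txt")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    next(iter(dataloader))  # a single-element list containing one prompt string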
30  tools/ptq/example.yml  Normal file
@@ -0,0 +1,30 @@
# Quantization Configuration for Checkpoint Merger
#
# This file defines which layers to quantize and what precision to use.
# Patterns use glob-style syntax where * matches any characters.

# Glob-style patterns of layers to DISABLE quantization.
# If a layer matches any pattern here, it will NOT be quantized.
disable_list: [
    # Example: disable input/output projection layers
    # "*img_in*",
    # "*txt_in*",
    # "*final_layer*",

    # Example: disable specific block types
    # "*norm*",
    # "*time_in*",
]

# Per-layer dtype configuration
# Maps layer name patterns to quantization formats (torch dtype names).
# Layers are matched in order - first match wins.
per_layer_dtype: {
    # Default: quantize all layers to FP8 E4M3
    "*": "float8_e4m3fn",

    # Example: use different precision for specific layers
    # "*attn*": "float8_e4m3fn",   # Attention layers
    # "*mlp*": "float8_e4m3fn",    # MLP layers
    # "*qkv*": "float8_e4m3fn",    # Q/K/V projections
}
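Because get_dtype in checkpoint_merger.py returns the first matching rule, and yaml.safe_load preserves the order in which mappings are written, specific rules must be listed before the "*" catch-all. In Python terms (a hypothetical mixed-precision config):

    rules = {"*attn*": "float8_e4m3fn", "*": "float4_e2m1fn_x2"}
    # "double_blocks.0.img_attn.qkv" hits "*attn*" first -> FP8;
    # anything else falls through to the "*" rule -> NVFP4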
76  tools/ptq/models/__init__.py  Normal file
@@ -0,0 +1,76 @@
"""
Model recipe registry for PTQ toolkit.

Recipes define model-specific quantization logic and are registered
via the @register_recipe decorator.
"""
from typing import Dict, Type
from .base import ModelRecipe


_RECIPE_REGISTRY: Dict[str, Type[ModelRecipe]] = {}


def register_recipe(recipe_cls: Type[ModelRecipe]):
    """
    Decorator to register a model recipe.

    Example:
        @register_recipe
        class FluxDevRecipe(ModelRecipe):
            @classmethod
            def name(cls):
                return "flux_dev"
            ...

    Args:
        recipe_cls: Recipe class inheriting from ModelRecipe

    Returns:
        The recipe class (unchanged)
    """
    recipe_name = recipe_cls.name()
    if recipe_name in _RECIPE_REGISTRY:
        raise ValueError(f"Recipe '{recipe_name}' is already registered")

    _RECIPE_REGISTRY[recipe_name] = recipe_cls
    return recipe_cls


def get_recipe_class(name: str) -> Type[ModelRecipe]:
    """
    Get recipe class by name.

    Args:
        name: Recipe name (e.g., 'flux_dev')

    Returns:
        Recipe class

    Raises:
        ValueError: If recipe name is not found
    """
    if name not in _RECIPE_REGISTRY:
        available = ", ".join(sorted(_RECIPE_REGISTRY.keys()))
        raise ValueError(
            f"Unknown model type '{name}'. "
            f"Available recipes: {available}"
        )
    return _RECIPE_REGISTRY[name]


def list_recipes():
    """
    List all available recipe names.

    Returns:
        List of recipe names (sorted)
    """
    return sorted(_RECIPE_REGISTRY.keys())


# Import recipe modules to trigger registration
from . import flux  # noqa: F401, E402
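A minimal registration sketch (the recipe is hypothetical; its remaining abstract methods are omitted for brevity):

    from tools.ptq.models import register_recipe, get_recipe_class
    from tools.ptq.models.base import ModelRecipe

    @register_recipe
    class ToyRecipe(ModelRecipe):
        @classmethod
        def name(cls) -> str:
            return "toy"

    get_recipe_class("toy")  # -> ToyRecipe; an unknown name raises ValueError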
124  tools/ptq/models/base.py  Normal file
@@ -0,0 +1,124 @@
"""
Abstract base class for model quantization recipes.

Each model type (FLUX, SDXL, Qwen, etc.) implements this interface
to define model-specific quantization logic.
"""
from abc import ABC, abstractmethod
import argparse
from typing import Tuple, Any, Callable


class ModelRecipe(ABC):
    """
    Abstract base class for model quantization recipes.

    Each model type implements this interface to define:
    - How to load the model
    - How to create calibration pipeline
    - How to run calibration (forward_loop)
    - Which layers to quantize (filter function)
    - Model-specific hyperparameters
    """

    @classmethod
    @abstractmethod
    def name(cls) -> str:
        """
        Unique identifier for this recipe (e.g., 'flux_dev', 'sdxl').
        Used in CLI --model_type argument.
        """
        pass

    @classmethod
    @abstractmethod
    def add_model_args(cls, parser: argparse.ArgumentParser):
        """
        Add model-specific CLI arguments.

        Example:
            parser.add_argument("--ckpt_path", required=True)
            parser.add_argument("--clip_path", help="Optional CLIP path")

        Args:
            parser: ArgumentParser to add arguments to
        """
        pass

    @abstractmethod
    def __init__(self, args):
        """
        Initialize recipe with parsed CLI arguments.
        Store model-specific configuration.

        Args:
            args: Parsed argparse.Namespace
        """
        pass

    @abstractmethod
    def load_model(self) -> Tuple[Any, ...]:
        """
        Load model from checkpoint(s).

        Returns:
            Tuple of (model_patcher, *other_components)
            e.g., (model_patcher, clip, vae) for FLUX

        First element MUST be model_patcher (ComfyUI ModelPatcher)
        """
        pass

    @abstractmethod
    def create_calibration_pipeline(self, model_components) -> Any:
        """
        Create calibration pipeline for running inference.

        The pipeline should have a __call__ method that runs inference
        for one calibration iteration.

        Args:
            model_components: Output from load_model()

        Returns:
            Pipeline object with __call__(steps, prompt, ...) method
        """
        pass

    @abstractmethod
    def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
        """
        Return forward_loop function for ModelOptimizer calibration.

        The forward_loop is called by mtq.quantize() to collect activation
        statistics. It should iterate through the dataloader and run
        inference using the calibration pipeline.

        Args:
            calib_pipeline: Output from create_calibration_pipeline()
            dataloader: DataLoader with calibration data

        Returns:
            Callable that takes no arguments and runs calibration loop

        Example:
            def forward_loop():
                for prompt in dataloader:
                    calib_pipeline(steps=4, prompt=prompt)
            return forward_loop
        """
        pass

    @abstractmethod
    def get_default_calib_steps(self) -> int:
        """
        Default number of calibration steps for this model.

        Returns:
            Number of calibration iterations (e.g., 128 for FLUX Dev)
        """
        pass
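End to end, tools/ptq/quantize.py drives a recipe roughly like this (a sketch of the flow, not the verbatim driver):

    recipe_cls = get_recipe_class(args.model_type)         # e.g. "flux_dev"
    recipe = recipe_cls(args)
    components = recipe.load_model()                       # (model_patcher, clip, vae)
    calib_pipe = recipe.create_calibration_pipeline(components)
    forward_loop = recipe.get_forward_loop(calib_pipe, dataloader)
    # mtq.quantize(...) later calls forward_loop() to collect amax statistics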
277  tools/ptq/models/flux.py  Normal file
@@ -0,0 +1,277 @@
"""
FLUX model quantization recipes.

Defines FluxDevRecipe and FluxSchnellRecipe for quantizing FLUX models.
"""
import logging
import comfy.sd
import folder_paths
from dataclasses import dataclass
from typing import Tuple, Callable

from comfy_extras.nodes_custom_sampler import (
    BasicGuider,
    SamplerCustomAdvanced,
    BasicScheduler,
    RandomNoise,
    KSamplerSelect,
)
from comfy_extras.nodes_flux import FluxGuidance
from comfy_extras.nodes_model_advanced import ModelSamplingFlux
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from nodes import CLIPTextEncode

from . import register_recipe
from .base import ModelRecipe


@dataclass
class SamplerCFG:
    cfg: float
    sampler_name: str
    scheduler: str
    denoise: float
    max_shift: float
    base_shift: float


class FluxT2IPipe:
    def __init__(
        self,
        model,
        clip,
        batch_size,
        width=1024,
        height=1024,
        seed=0,
        sampler_cfg: SamplerCFG = None,
        device="cuda",
    ) -> None:
        self.clip = clip
        self.clip_node = CLIPTextEncode()
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.max_shift = sampler_cfg.max_shift
        self.base_shift = sampler_cfg.base_shift
        self.device = device

        self.seed = seed
        self.cfg = sampler_cfg.cfg
        self.scheduler_name = sampler_cfg.scheduler
        self.denoise = sampler_cfg.denoise

        (self.model,) = ModelSamplingFlux().patch(
            model, self.max_shift, self.base_shift, self.width, self.height
        )
        (self.ksampler,) = KSamplerSelect().get_sampler(sampler_cfg.sampler_name)
        self.latent_node = EmptySD3LatentImage()
        self.guidance = FluxGuidance()
        self.sampler = SamplerCustomAdvanced()
        self.scheduler_node = BasicScheduler()
        self.guider = BasicGuider()
        self.noise_generator = RandomNoise()

    def __call__(self, num_inference_steps, positive_prompt, *args, **kwargs):
        (positive,) = self.clip_node.encode(self.clip, positive_prompt)
        (latent_image,) = self.latent_node.generate(
            self.width, self.height, self.batch_size
        )
        (noise,) = self.noise_generator.get_noise(self.seed)

        (conditioning,) = self.guidance.append(positive, self.cfg)
        (sigmas,) = self.scheduler_node.get_sigmas(
            self.model, self.scheduler_name, num_inference_steps, self.denoise
        )
        (guider,) = self.guider.get_guider(self.model, conditioning)

        out, denoised_out = self.sampler.sample(
            noise, guider, self.ksampler, sigmas, latent_image
        )

        return out["samples"]


class FluxRecipeBase(ModelRecipe):
    """Base class for FLUX model quantization recipes."""

    @classmethod
    def add_model_args(cls, parser):
        """Add FLUX-specific CLI arguments."""
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument(
            "--ckpt_path",
            help="Path to full FLUX checkpoint (includes diffusion model + CLIP + T5)"
        )
        group.add_argument(
            "--unet_path",
            help="Path to FLUX diffusion model only (requires --clip_path and --t5_path)"
        )

        parser.add_argument(
            "--clip_path",
            help="Path to CLIP text encoder (required with --unet_path)"
        )
        parser.add_argument(
            "--t5_path",
            help="Path to T5 text encoder (required with --unet_path)"
        )

    def __init__(self, args):
        """Initialize FLUX recipe with CLI args."""
        self.args = args

        # Validate args
        if hasattr(args, 'unet_path') and args.unet_path:
            if not args.clip_path or not args.t5_path:
                raise ValueError("--unet_path requires both --clip_path and --t5_path")

    def load_model(self) -> Tuple:
        """Load FLUX model, CLIP, and VAE."""
        if hasattr(self.args, 'ckpt_path') and self.args.ckpt_path:
            # Load from full checkpoint
            logging.info(f"Loading full checkpoint from {self.args.ckpt_path}")
            model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
                self.args.ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=None
            )
        else:
            # Load from separate files
            logging.info(f"Loading diffusion model from {self.args.unet_path}")
            model_options = {}
            clip_type = comfy.sd.CLIPType.FLUX

            clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path)
            clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path)

            model_patcher = comfy.sd.load_diffusion_model(
                self.args.unet_path,
                model_options=model_options
            )
            clip = comfy.sd.load_clip(
                ckpt_paths=[clip_path1, clip_path2],
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
                clip_type=clip_type,
                model_options=model_options
            )
            vae = None  # Not needed for calibration

        return model_patcher, clip, vae

    def create_calibration_pipeline(self, model_components):
        """Create FluxT2IPipe for calibration."""
        model_patcher, clip, vae = model_components

        return FluxT2IPipe(
            model=model_patcher,
            clip=clip,
            batch_size=1,
            width=self.get_width(),
            height=self.get_height(),
            seed=42,
            sampler_cfg=self.get_sampler_cfg(),
            device="cuda"
        )

    def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
        """
        Return forward_loop for ModelOptimizer calibration.

        Iterates through the dataloader and runs full sampling
        for each prompt to collect activation statistics.
        """
        num_steps = self.get_inference_steps()

        def forward_loop():
            for i, prompt in enumerate(dataloader):
                # Dataloader returns batches, extract first element
                prompt_text = prompt[0] if isinstance(prompt, (list, tuple)) else prompt

                logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'")

                try:
                    # Run full sampling pipeline
                    calib_pipeline(num_steps, prompt_text)
                except Exception as e:
                    logging.warning(f"Calibration step {i+1} failed: {e}")
                    # Continue with next prompt

        return forward_loop

    # Abstract methods for variants to implement
    def get_width(self) -> int:
        """Image width for calibration."""
        raise NotImplementedError

    def get_height(self) -> int:
        """Image height for calibration."""
        raise NotImplementedError

    def get_sampler_cfg(self) -> SamplerCFG:
        """Sampler configuration."""
        raise NotImplementedError

    def get_inference_steps(self) -> int:
        """Number of sampling steps per calibration iteration."""
        raise NotImplementedError


@register_recipe
class FluxDevRecipe(FluxRecipeBase):
    """FLUX Dev quantization recipe."""

    @classmethod
    def name(cls) -> str:
        return "flux_dev"

    def get_default_calib_steps(self) -> int:
        return 128

    def get_width(self) -> int:
        return 1024

    def get_height(self) -> int:
        return 1024

    def get_inference_steps(self) -> int:
        return 30

    def get_sampler_cfg(self) -> SamplerCFG:
        return SamplerCFG(
            cfg=3.5,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
            max_shift=1.15,
            base_shift=0.5
        )


@register_recipe
class FluxSchnellRecipe(FluxRecipeBase):
    """FLUX Schnell quantization recipe."""

    @classmethod
    def name(cls) -> str:
        return "flux_schnell"

    def get_default_calib_steps(self) -> int:
        return 64

    def get_width(self) -> int:
        return 1024

    def get_height(self) -> int:
        return 1024

    def get_inference_steps(self) -> int:
        return 4

    def get_sampler_cfg(self) -> SamplerCFG:
        return SamplerCFG(
            cfg=1.0,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
            max_shift=1.15,
            base_shift=0.5
        )
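Standalone use of the pipe, equivalent to what FluxDevRecipe wires up; model_patcher and clip are assumed to come from load_model(), and the prompt is arbitrary:

    cfg = SamplerCFG(cfg=3.5, sampler_name="euler", scheduler="simple",
                     denoise=1.0, max_shift=1.15, base_shift=0.5)
    pipe = FluxT2IPipe(model=model_patcher, clip=clip, batch_size=1,
                       width=1024, height=1024, seed=42, sampler_cfg=cfg)
    latents = pipe(30, "a watercolor fox in a snowy forest")  # latent samples; no VAE decode needed for calibration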
187  tools/ptq/quantize.py  Normal file
@@ -0,0 +1,187 @@
import argparse
import logging
import sys
import torch.utils.data

import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from tools.ptq.models import get_recipe_class, list_recipes
from tools.ptq.quantizer import PTQPipeline
from tools.ptq.utils import register_comfy_ops, FP8_CFG
from tools.ptq.dataset.t2i import PromptDataset


def main():
    """Main entry point for PTQ CLI."""

    # Step 1: Parse model_type first to determine which recipe to use
    parser = argparse.ArgumentParser(
        description="Quantize ComfyUI models using NVIDIA ModelOptimizer",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--model_type",
        required=True,
        choices=list_recipes(),
        help="Model recipe to use"
    )

    # Parse just model_type first to get recipe class
    args, remaining = parser.parse_known_args()

    # Step 2: Get recipe class and add its model-specific arguments
    recipe_cls = get_recipe_class(args.model_type)
    recipe_cls.add_model_args(parser)

    # Step 3: Add common arguments
    parser.add_argument(
        "--output",
        required=True,
        help="Output path for amax artefact"
    )
    parser.add_argument(
        "--calib_steps",
        type=int,
        help="Override default calibration steps"
    )
    parser.add_argument(
        "--calib_data",
        default="tools/ptq/data/calib_prompts.txt",
        help="Path to calibration prompts"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode (sets logging to DEBUG and calib_steps to 1)"
    )

    # Step 4: Parse all arguments
    args = parser.parse_args()

    # Configure logging
    if args.debug:
        logging.basicConfig(
            level=logging.DEBUG,
            format='[%(levelname)s] %(name)s: %(message)s'
        )
        logging.info("Debug mode enabled")
    else:
        logging.basicConfig(
            level=logging.INFO,
            format='[%(levelname)s] %(message)s'
        )

    # Step 5: Create recipe instance
    try:
        recipe = recipe_cls(args)
    except Exception as e:
        logging.error(f"Failed to initialize recipe: {e}")
        sys.exit(1)

    # Debug mode overrides calibration steps
    if args.debug:
        calib_steps = 1
        logging.debug("Debug mode: forcing calib_steps=1")
    elif args.calib_steps:
        calib_steps = args.calib_steps
    else:
        calib_steps = recipe.get_default_calib_steps()

    # Step 6: Register ComfyUI ops with ModelOptimizer
    logging.info("Registering ComfyUI ops with ModelOptimizer...")
    register_comfy_ops()

    # Step 7: Load model
    logging.info("[1/6] Loading model...")
    try:
        model_components = recipe.load_model()
        model_patcher = model_components[0]
    except Exception as e:
        logging.error(f"Failed to load model: {e}")
        sys.exit(1)

    # Step 8: Create PTQ pipeline
    logging.info("[2/6] Preparing quantization...")
    try:
        pipeline = PTQPipeline(
            model_patcher,
            quant_config=FP8_CFG,
            filter_func=recipe.get_filter_func()
        )
    except Exception as e:
        logging.error(f"Failed to prepare quantization: {e}")
        sys.exit(1)

    # Step 9: Create calibration pipeline
    logging.info("[3/6] Creating calibration pipeline...")
    try:
        calib_pipeline = recipe.create_calibration_pipeline(model_components)
    except Exception as e:
        logging.error(f"Failed to create calibration pipeline: {e}")
        sys.exit(1)

    # Step 10: Load calibration data
    logging.info(f"[4/6] Loading calibration data from {args.calib_data}")
    try:
        dataset = PromptDataset(args.calib_data)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
        logging.info(f"Loaded {len(dataset)} prompts")
    except Exception as e:
        logging.error(f"Failed to load calibration data: {e}")
        sys.exit(1)

    # Step 11: Run calibration
    logging.info(f"[5/6] Running calibration ({calib_steps} steps)...")
    try:
        pipeline.calibrate_with_pipeline(
            calib_pipeline,
            dataloader,
            num_steps=calib_steps,
            get_forward_loop=recipe.get_forward_loop
        )
    except Exception as e:
        logging.error(f"Calibration failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # Only save amax values
    logging.info("[6/6] Extracting and saving amax values...")
    try:
        # Build metadata
        metadata = {
            "model_type": recipe.name(),
            "calibration_steps": calib_steps,
            "calibration_data": args.calib_data,
            "quantization_format": "FP8_E4M3",
            "debug_mode": args.debug
        }

        # Add checkpoint path if available
        if hasattr(args, 'ckpt_path') and args.ckpt_path:
            metadata["checkpoint_path"] = args.ckpt_path

        pipeline.save_amax_values(args.output, metadata=metadata)
    except Exception as e:
        logging.error(f"Failed to save amax values: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
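The intended two-stage workflow, with placeholder paths: calibrate to produce the amax artefact, then merge it into the checkpoint with checkpoint_merger.py:

    python tools/ptq/quantize.py --model_type flux_dev \
        --ckpt_path /path/to/flux1-dev.safetensors \
        --output flux_dev_amax.json
    python tools/ptq/checkpoint_merger.py --artefact flux_dev_amax.json \
        --checkpoint /path/to/flux1-dev.safetensors \
        --config tools/ptq/configs/flux_fp8.yml \
        --output flux1-dev-fp8.safetensors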
59  tools/ptq/quantizer.py  Normal file
@@ -0,0 +1,59 @@
import torch
import logging
from typing import Dict, Callable
import itertools

import modelopt.torch.quantization as mtq
from tools.ptq.utils import log_quant_summary, save_amax_dict, extract_amax_values


class PTQPipeline:
    def __init__(self, model_patcher, quant_config: dict, filter_func=None):
        self.model_patcher = model_patcher
        self.diffusion_model = model_patcher.model.diffusion_model
        self.quant_config = quant_config
        self.filter_func = filter_func

        logging.debug(f"PTQPipeline initialized with config: {quant_config}")

    @torch.no_grad()
    def calibrate_with_pipeline(
        self,
        calib_pipeline,
        dataloader,
        num_steps: int,
        get_forward_loop: Callable
    ):
        """
        Run calibration using the model-specific forward loop.

        Args:
            calib_pipeline: Calibration pipeline (e.g., FluxT2IPipe)
            dataloader: DataLoader with calibration data
            num_steps: Number of calibration steps
            get_forward_loop: Function that returns forward_loop callable
        """
        logging.info(f"Running calibration with {num_steps} steps...")
        limited_dataloader = itertools.islice(dataloader, num_steps)
        forward_loop = get_forward_loop(calib_pipeline, limited_dataloader)
        try:
            # mtq.quantize inserts quantizers and invokes forward_loop itself;
            # calling forward_loop again afterwards would only iterate the
            # already-exhausted islice, so a single call is sufficient.
            mtq.quantize(self.diffusion_model, self.quant_config, forward_loop=forward_loop)
        except Exception as e:
            logging.error(f"Calibration failed: {e}")
            raise
        logging.info("Calibration complete")
        log_quant_summary(self.diffusion_model)

    def get_amax_dict(self) -> Dict:
        return extract_amax_values(self.diffusion_model)

    def save_amax_values(self, output_path: str, metadata: dict = None):
        amax_dict = self.get_amax_dict()
        save_amax_dict(amax_dict, output_path, metadata=metadata)
        logging.info(f"Saved amax values to {output_path}")
96  tools/ptq/utils.py  Normal file
@@ -0,0 +1,96 @@
import torch
import logging
from typing import Dict, Optional

import comfy.ops
from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer


# FP8 E4M3 configuration
FP8_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
        "*input_quantizer": {"num_bits": (4, 3), "axis": None},
        "default": {"enable": False},
    },
    "algorithm": "max",
}


def register_comfy_ops():
    """Register ComfyUI operations with ModelOptimizer."""
    op = comfy.ops.disable_weight_init.Linear
    op_name = op.__name__

    if op in QuantModuleRegistry:
        logging.debug("ComfyUI Linear already registered with ModelOptimizer")
        return

    # Register ComfyUI Linear using the same handler as torch.nn.Linear
    QuantModuleRegistry.register(
        {op: f"comfy.{op_name}"}
    )(QuantModuleRegistry._registry[getattr(torch.nn, op_name)])

    logging.info("Registered ComfyUI Linear with ModelOptimizer")


def log_quant_summary(model: torch.nn.Module, log_level=logging.INFO):
    count = 0
    for name, mod in model.named_modules():
        if isinstance(mod, TensorQuantizer):
            logging.log(log_level, f"{name:80} {mod}")
            count += 1
    logging.log(log_level, f"{count} TensorQuantizers found in model")


def extract_amax_values(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    amax_dict = {}

    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        if not module.is_enabled:
            continue
        if hasattr(module, '_amax') and module._amax is not None:
            amax = module._amax
            if not isinstance(amax, torch.Tensor):
                amax = torch.tensor(amax, dtype=torch.float32)

            amax_dict[name] = amax.clone().cpu()
            logging.debug(f"Extracted amax from {name}: {amax.item():.6f}")

    logging.info(f"Extracted amax values from {len(amax_dict)} quantizers")
    return amax_dict


def save_amax_dict(amax_dict: Dict[str, torch.Tensor], output_path: str, metadata: Optional[Dict] = None):
    import json
    from datetime import datetime

    logging.info(f"Saving {len(amax_dict)} amax values to {output_path}")

    # Convert tensors to Python floats for JSON serialization
    amax_values = {}
    for key, value in amax_dict.items():
        if isinstance(value, torch.Tensor):
            # Convert to float (scalar) or list
            if value.numel() == 1:
                amax_values[key] = float(value.item())
            else:
                amax_values[key] = value.cpu().numpy().tolist()
        else:
            amax_values[key] = float(value)

    # Build output with metadata
    output_dict = {
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "num_layers": len(amax_values),
            **(metadata or {})
        },
        "amax_values": amax_values
    }

    # Save as formatted JSON for easy inspection
    with open(output_path, 'w') as f:
        json.dump(output_dict, f, indent=2, sort_keys=True)

    logging.info(f"✓ Amax values saved to {output_path}")
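On the FP8_CFG shape: num_bits=(4, 3) selects an FP8 E4M3 quantizer (4 exponent bits, 3 mantissa bits) and axis=None makes the amax per-tensor; the wildcards match quantizer module names such as double_blocks.0.img_attn.qkv.weight_quantizer. This appears to mirror ModelOptimizer's own FP8 default:

    import modelopt.torch.quantization as mtq
    print(mtq.FP8_DEFAULT_CFG["quant_cfg"]["*weight_quantizer"])  # expected: {'num_bits': (4, 3), 'axis': None}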