ComfyUI/tools/ptq/quantizer.py
2025-10-28 08:26:02 +01:00

60 lines
2.1 KiB
Python

import torch
import logging
from typing import Dict, Callable
import itertools
import modelopt.torch.quantization as mtq
from tools.ptq.utils import log_quant_summary, save_amax_dict, extract_amax_values
class PTQPipeline:
def __init__(self, model_patcher, quant_config: dict, filter_func=None):
self.model_patcher = model_patcher
self.diffusion_model = model_patcher.model.diffusion_model
self.quant_config = quant_config
self.filter_func = filter_func
logging.debug(f"PTQPipeline initialized with config: {quant_config}")
@torch.no_grad()
def calibrate_with_pipeline(
self,
calib_pipeline,
dataloader,
num_steps: int,
get_forward_loop: Callable
):
"""
Run calibration using the model-specific forward loop.
Args:
calib_pipeline: Calibration pipeline (e.g., FluxT2IPipe)
dataloader: DataLoader with calibration data
num_steps: Number of calibration steps
get_forward_loop: Function that returns forward_loop callable
"""
logging.info(f"Running calibration with {num_steps} steps...")
limited_dataloader = itertools.islice(dataloader, num_steps)
forward_loop = get_forward_loop(calib_pipeline, limited_dataloader)
try:
mtq.quantize(self.diffusion_model, self.quant_config, forward_loop=forward_loop)
except Exception as e:
logging.error(f"Calibration failed: {e}")
raise
try:
forward_loop()
except Exception as e:
logging.error(f"Calibration failed: {e}")
raise
logging.info("Calibration complete")
log_quant_summary(self.diffusion_model)
def get_amax_dict(self) -> Dict:
return extract_amax_values(self.diffusion_model)
def save_amax_values(self, output_path: str, metadata: dict = None):
amax_dict = self.get_amax_dict()
save_amax_dict(amax_dict, output_path, metadata=metadata)
logging.info(f"Saved amax values to {output_path}")