# ComfyUI/tools/ptq/models/flux.py
# 2025-10-28 08:26:02 +01:00

# 278 lines
# 8.4 KiB
# Python

"""
FLUX model quantization recipes.
Defines FluxDevRecipe and FluxSchnellRecipe for quantizing FLUX models.
"""
import logging
import comfy.sd
import folder_paths
from dataclasses import dataclass
from typing import Tuple, Callable
from comfy_extras.nodes_custom_sampler import (
BasicGuider,
SamplerCustomAdvanced,
BasicScheduler,
RandomNoise,
KSamplerSelect,
)
from comfy_extras.nodes_flux import FluxGuidance
from comfy_extras.nodes_model_advanced import ModelSamplingFlux
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from nodes import CLIPTextEncode
from . import register_recipe
from .base import ModelRecipe
@dataclass
class SamplerCFG:
    """Bundle of sampler settings consumed by FluxT2IPipe.

    The field order below is also the positional constructor order, so it
    must not be rearranged.
    """

    # Guidance value handed to FluxGuidance.append().
    cfg: float
    # Sampler identifier resolved through KSamplerSelect.get_sampler().
    sampler_name: str
    # Scheduler identifier forwarded to BasicScheduler.get_sigmas().
    scheduler: str
    # Denoise amount forwarded to BasicScheduler.get_sigmas().
    denoise: float
    # Shift arguments forwarded to ModelSamplingFlux.patch().
    max_shift: float
    base_shift: float
class FluxT2IPipe:
    """Text-to-image FLUX sampling pipeline built from ComfyUI node classes.

    The constructor patches the model with ModelSamplingFlux and instantiates
    the node objects once; calling the instance encodes a prompt, creates an
    empty latent and runs SamplerCustomAdvanced over it.
    """

    def __init__(
        self,
        model,
        clip,
        batch_size,
        width=1024,
        height=1024,
        seed=0,
        sampler_cfg: "SamplerCFG | None" = None,
        device="cuda",
    ) -> None:
        """Prepare the patched model and the sampling nodes.

        Args:
            model: Model object passed to ModelSamplingFlux.patch().
            clip: CLIP object used by CLIPTextEncode.encode().
            batch_size: Batch size for the empty latent.
            width: Latent/image width passed to the patch and latent nodes.
            height: Latent/image height passed to the patch and latent nodes.
            seed: Seed passed to RandomNoise.get_noise().
            sampler_cfg: Sampler settings. Although the signature defaults to
                None for backward compatibility, a value is required.
            device: Stored on the instance; not otherwise used by this class.

        Raises:
            ValueError: If sampler_cfg is None.
        """
        # Fail early with a clear message instead of an AttributeError on
        # NoneType when the first sampler_cfg field is read below.
        if sampler_cfg is None:
            raise ValueError(
                "FluxT2IPipe requires a sampler_cfg (SamplerCFG instance); got None"
            )

        self.clip = clip
        self.clip_node = CLIPTextEncode()

        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.max_shift = sampler_cfg.max_shift
        self.base_shift = sampler_cfg.base_shift
        self.device = device
        self.seed = seed
        self.cfg = sampler_cfg.cfg
        self.scheduler_name = sampler_cfg.scheduler
        self.denoise = sampler_cfg.denoise

        # Node methods return one-element tuples, hence the unpacking.
        (self.model,) = ModelSamplingFlux().patch(
            model, self.max_shift, self.base_shift, self.width, self.height
        )
        (self.ksampler,) = KSamplerSelect().get_sampler(sampler_cfg.sampler_name)

        self.latent_node = EmptySD3LatentImage()
        self.guidance = FluxGuidance()
        self.sampler = SamplerCustomAdvanced()
        self.scheduler_node = BasicScheduler()
        self.guider = BasicGuider()
        self.noise_generator = RandomNoise()

    def __call__(self, num_inference_steps, positive_prompt, *args, **kwargs):
        """Run one full sampling pass for a prompt.

        Args:
            num_inference_steps: Step count passed to BasicScheduler.get_sigmas().
            positive_prompt: Prompt text encoded with the stored CLIP object.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            The "samples" entry of the first output of SamplerCustomAdvanced.sample().
        """
        (positive,) = self.clip_node.encode(self.clip, positive_prompt)
        (latent_image,) = self.latent_node.generate(
            self.width, self.height, self.batch_size
        )
        (noise,) = self.noise_generator.get_noise(self.seed)
        (conditioning,) = self.guidance.append(positive, self.cfg)
        (sigmas,) = self.scheduler_node.get_sigmas(
            self.model, self.scheduler_name, num_inference_steps, self.denoise
        )
        (guider,) = self.guider.get_guider(self.model, conditioning)

        # The sampler returns two outputs; only the first is used here.
        out, _denoised = self.sampler.sample(
            noise, guider, self.ksampler, sigmas, latent_image
        )
        return out["samples"]
class FluxRecipeBase(ModelRecipe):
    """Base class for FLUX model quantization recipes.

    Subclasses supply the calibration resolution, sampler configuration and
    step count by overriding the get_* hooks at the bottom of the class.
    """

    @classmethod
    def add_model_args(cls, parser):
        """Add FLUX-specific CLI arguments.

        Exactly one of --ckpt_path / --unet_path must be given (enforced by a
        required mutually exclusive group); --clip_path and --t5_path are
        plain optional arguments validated later in __init__.
        """
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument(
            "--ckpt_path",
            help="Path to full FLUX checkpoint (includes diffusion model + CLIP + T5)"
        )
        group.add_argument(
            "--unet_path",
            help="Path to FLUX diffusion model only (requires --clip_path and --t5_path)"
        )
        parser.add_argument(
            "--clip_path",
            help="Path to CLIP text encoder (required with --unet_path)"
        )
        parser.add_argument(
            "--t5_path",
            help="Path to T5 text encoder (required with --unet_path)"
        )

    def __init__(self, args):
        """Initialize FLUX recipe with CLI args.

        Raises:
            ValueError: If unet_path is set without both clip_path and t5_path.
        """
        self.args = args

        # Use getattr for every field so a namespace that lacks clip_path or
        # t5_path yields the intended ValueError rather than AttributeError.
        if getattr(args, "unet_path", None):
            has_clip = bool(getattr(args, "clip_path", None))
            has_t5 = bool(getattr(args, "t5_path", None))
            if not (has_clip and has_t5):
                raise ValueError("--unet_path requires both --clip_path and --t5_path")

    def load_model(self) -> Tuple:
        """Load the FLUX model components.

        Returns:
            A (model_patcher, clip, vae) tuple. vae is None when loading from
            separate files because it is not needed for calibration.

        Raises:
            ValueError: If neither ckpt_path nor unet_path is set on the args.
        """
        ckpt_path = getattr(self.args, "ckpt_path", None)
        if ckpt_path:
            # Load from full checkpoint
            logging.info(f"Loading full checkpoint from {ckpt_path}")
            model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=None
            )
            return model_patcher, clip, vae

        unet_path = getattr(self.args, "unet_path", None)
        if not unet_path:
            # Without this guard None would be passed to the loaders below and
            # fail with an unrelated error.
            raise ValueError("Either --ckpt_path or --unet_path must be provided")

        # Load from separate files
        logging.info(f"Loading diffusion model from {unet_path}")
        model_options = {}
        clip_type = comfy.sd.CLIPType.FLUX

        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path)
        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path)

        model_patcher = comfy.sd.load_diffusion_model(
            unet_path,
            model_options=model_options
        )
        clip = comfy.sd.load_clip(
            ckpt_paths=[clip_path1, clip_path2],
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            clip_type=clip_type,
            model_options=model_options
        )
        vae = None  # Not needed for calibration
        return model_patcher, clip, vae

    def create_calibration_pipeline(self, model_components):
        """Create FluxT2IPipe for calibration.

        Args:
            model_components: The (model_patcher, clip, vae) tuple returned by
                load_model(); the vae element is not used.
        """
        model_patcher, clip, _vae = model_components

        return FluxT2IPipe(
            model=model_patcher,
            clip=clip,
            batch_size=1,
            width=self.get_width(),
            height=self.get_height(),
            seed=42,
            sampler_cfg=self.get_sampler_cfg(),
            device="cuda"
        )

    def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
        """Return forward_loop for ModelOptimizer calibration.

        The returned callable iterates through the dataloader and runs full
        sampling for each prompt. A failing step is logged and skipped, but
        if at least one step failed and none succeeded the callable raises,
        because continuing would mean calibrating on no data at all.

        Args:
            calib_pipeline: Callable invoked as calib_pipeline(num_steps, prompt).
            dataloader: Iterable of prompts, or of list/tuple batches whose
                first element is the prompt.

        Returns:
            A zero-argument callable that performs the calibration passes.
        """
        num_steps = self.get_inference_steps()

        def forward_loop():
            succeeded = 0
            # Keep only the text of the last error: holding the exception
            # object would keep its traceback frames (and whatever they
            # reference) alive across iterations.
            last_error_text = None

            for i, prompt in enumerate(dataloader):
                # Dataloader returns batches, extract first element
                prompt_text = prompt[0] if isinstance(prompt, (list, tuple)) else prompt

                logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'")

                try:
                    # Run full sampling pipeline
                    calib_pipeline(num_steps, prompt_text)
                except Exception as e:
                    last_error_text = f"{type(e).__name__}: {e}"
                    # exc_info keeps the traceback in the log for diagnosis.
                    logging.warning(f"Calibration step {i+1} failed: {e}", exc_info=True)
                    # Continue with next prompt
                else:
                    succeeded += 1

            if last_error_text is not None and succeeded == 0:
                raise RuntimeError(
                    "All calibration steps failed; last error: " + last_error_text
                )

        return forward_loop

    # Abstract methods for variants to implement
    def get_width(self) -> int:
        """Image width for calibration."""
        raise NotImplementedError

    def get_height(self) -> int:
        """Image height for calibration."""
        raise NotImplementedError

    def get_sampler_cfg(self) -> SamplerCFG:
        """Sampler configuration."""
        raise NotImplementedError

    def get_inference_steps(self) -> int:
        """Number of sampling steps per calibration iteration."""
        raise NotImplementedError
@register_recipe
class FluxDevRecipe(FluxRecipeBase):
    """Recipe for FLUX Dev, exposed under the name "flux_dev"."""

    @classmethod
    def name(cls) -> str:
        """Return the recipe identifier string."""
        recipe_id = "flux_dev"
        return recipe_id

    def get_default_calib_steps(self) -> int:
        """Return the default calibration step count (128)."""
        default_steps = 128
        return default_steps

    def get_width(self) -> int:
        """Return the calibration image width in pixels."""
        width_px = 1024
        return width_px

    def get_height(self) -> int:
        """Return the calibration image height in pixels."""
        height_px = 1024
        return height_px

    def get_inference_steps(self) -> int:
        """Return the number of sampling steps per calibration iteration."""
        steps_per_prompt = 30
        return steps_per_prompt

    def get_sampler_cfg(self) -> SamplerCFG:
        """Return the sampler settings used for FLUX Dev calibration."""
        settings = {
            "cfg": 3.5,
            "sampler_name": "euler",
            "scheduler": "simple",
            "denoise": 1.0,
            "max_shift": 1.15,
            "base_shift": 0.5,
        }
        return SamplerCFG(**settings)
@register_recipe
class FluxSchnellRecipe(FluxRecipeBase):
    """Recipe for FLUX Schnell, exposed under the name "flux_schnell"."""

    @classmethod
    def name(cls) -> str:
        """Return the recipe identifier string."""
        recipe_id = "flux_schnell"
        return recipe_id

    def get_default_calib_steps(self) -> int:
        """Return the default calibration step count (64)."""
        default_steps = 64
        return default_steps

    def get_width(self) -> int:
        """Return the calibration image width in pixels."""
        width_px = 1024
        return width_px

    def get_height(self) -> int:
        """Return the calibration image height in pixels."""
        height_px = 1024
        return height_px

    def get_inference_steps(self) -> int:
        """Return the number of sampling steps per calibration iteration."""
        steps_per_prompt = 4
        return steps_per_prompt

    def get_sampler_cfg(self) -> SamplerCFG:
        """Return the sampler settings used for FLUX Schnell calibration."""
        settings = {
            "cfg": 1.0,
            "sampler_name": "euler",
            "scheduler": "simple",
            "denoise": 1.0,
            "max_shift": 1.15,
            "base_shift": 0.5,
        }
        return SamplerCFG(**settings)