"""
|
|
FLUX model quantization recipes.
|
|
|
|
Defines FluxDevRecipe and FluxSchnellRecipe for quantizing FLUX models.
|
|
"""

import logging
import comfy.sd
import folder_paths
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

from comfy_extras.nodes_custom_sampler import (
    BasicGuider,
    SamplerCustomAdvanced,
    BasicScheduler,
    RandomNoise,
    KSamplerSelect,
)
from comfy_extras.nodes_flux import FluxGuidance
from comfy_extras.nodes_model_advanced import ModelSamplingFlux
from comfy_extras.nodes_sd3 import EmptySD3LatentImage
from nodes import CLIPTextEncode

from . import register_recipe
from .base import ModelRecipe


@dataclass
class SamplerCFG:
    """Sampler settings used by FluxT2IPipe during calibration sampling."""

    cfg: float
    sampler_name: str
    scheduler: str
    denoise: float
    max_shift: float
    base_shift: float


class FluxT2IPipe:
    """Minimal FLUX text-to-image pipeline assembled from ComfyUI nodes.

    Used to drive full sampling runs while collecting calibration statistics.
    """

    def __init__(
        self,
        model,
        clip,
        batch_size,
        width=1024,
        height=1024,
        seed=0,
        sampler_cfg: Optional[SamplerCFG] = None,
        device="cuda",
    ) -> None:
        self.clip = clip
        self.clip_node = CLIPTextEncode()
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.max_shift = sampler_cfg.max_shift
        self.base_shift = sampler_cfg.base_shift
        self.device = device

        self.seed = seed
        self.cfg = sampler_cfg.cfg
        self.scheduler_name = sampler_cfg.scheduler
        self.denoise = sampler_cfg.denoise

        (self.model,) = ModelSamplingFlux().patch(
            model, self.max_shift, self.base_shift, self.width, self.height
        )
        (self.ksampler,) = KSamplerSelect().get_sampler(sampler_cfg.sampler_name)
        self.latent_node = EmptySD3LatentImage()
        self.guidance = FluxGuidance()
        self.sampler = SamplerCustomAdvanced()
        self.scheduler_node = BasicScheduler()
        self.guider = BasicGuider()
        self.noise_generator = RandomNoise()

    def __call__(self, num_inference_steps, positive_prompt, *args, **kwargs):
        """Run text-to-image sampling and return the resulting latent samples."""
        (positive,) = self.clip_node.encode(self.clip, positive_prompt)
        (latent_image,) = self.latent_node.generate(
            self.width, self.height, self.batch_size
        )
        (noise,) = self.noise_generator.get_noise(self.seed)

        (conditioning,) = self.guidance.append(positive, self.cfg)
        (sigmas,) = self.scheduler_node.get_sigmas(
            self.model, self.scheduler_name, num_inference_steps, self.denoise
        )
        (guider,) = self.guider.get_guider(self.model, conditioning)

        out, _denoised = self.sampler.sample(
            noise, guider, self.ksampler, sigmas, latent_image
        )

        return out["samples"]
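
# Example (illustrative sketch only; `model_patcher` and `clip` are assumed to come
# from FluxRecipeBase.load_model() below):
#
#   pipe = FluxT2IPipe(
#       model=model_patcher,
#       clip=clip,
#       batch_size=1,
#       sampler_cfg=SamplerCFG(cfg=3.5, sampler_name="euler", scheduler="simple",
#                              denoise=1.0, max_shift=1.15, base_shift=0.5),
#   )
#   latents = pipe(30, "a photo of a red fox in the snow")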


class FluxRecipeBase(ModelRecipe):
    """Base class for FLUX model quantization recipes."""

    @classmethod
    def add_model_args(cls, parser):
        """Add FLUX-specific CLI arguments."""
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument(
            "--ckpt_path",
            help="Path to full FLUX checkpoint (includes diffusion model + CLIP + T5)"
        )
        group.add_argument(
            "--unet_path",
            help="Path to FLUX diffusion model only (requires --clip_path and --t5_path)"
        )

        parser.add_argument(
            "--clip_path",
            help="Path to CLIP text encoder (required with --unet_path)"
        )
        parser.add_argument(
            "--t5_path",
            help="Path to T5 text encoder (required with --unet_path)"
        )
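
        # Example invocations (illustrative; the driver script name and file names
        # are hypothetical):
        #   quantize.py --ckpt_path /models/flux1-dev-full.safetensors
        #   quantize.py --unet_path /models/flux1-dev.safetensors \
        #       --clip_path clip_l.safetensors --t5_path t5xxl_fp16.safetensors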

    def __init__(self, args):
        """Initialize FLUX recipe with CLI args."""
        self.args = args

        # Validate args
        if hasattr(args, 'unet_path') and args.unet_path:
            if not args.clip_path or not args.t5_path:
                raise ValueError("--unet_path requires both --clip_path and --t5_path")

    def load_model(self) -> Tuple:
        """Load FLUX model, CLIP, and VAE."""
        if hasattr(self.args, 'ckpt_path') and self.args.ckpt_path:
            # Load from full checkpoint
            logging.info(f"Loading full checkpoint from {self.args.ckpt_path}")
            model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
                self.args.ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=None
            )
        else:
            # Load from separate files
            logging.info(f"Loading diffusion model from {self.args.unet_path}")
            model_options = {}
            clip_type = comfy.sd.CLIPType.FLUX

            clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", self.args.clip_path)
            clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", self.args.t5_path)

            model_patcher = comfy.sd.load_diffusion_model(
                self.args.unet_path,
                model_options=model_options
            )
            clip = comfy.sd.load_clip(
                ckpt_paths=[clip_path1, clip_path2],
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
                clip_type=clip_type,
                model_options=model_options
            )
            vae = None  # Not needed for calibration

        return model_patcher, clip, vae

    def create_calibration_pipeline(self, model_components):
        """Create FluxT2IPipe for calibration."""
        model_patcher, clip, vae = model_components

        return FluxT2IPipe(
            model=model_patcher,
            clip=clip,
            batch_size=1,
            width=self.get_width(),
            height=self.get_height(),
            seed=42,
            sampler_cfg=self.get_sampler_cfg(),
            device="cuda"
        )

    def get_forward_loop(self, calib_pipeline, dataloader) -> Callable:
        """
        Return forward_loop for ModelOptimizer calibration.

        Iterates through the dataloader and runs full sampling
        for each prompt to collect activation statistics.
        """
        num_steps = self.get_inference_steps()

        def forward_loop():
            for i, prompt in enumerate(dataloader):
                # Dataloader returns batches, extract first element
                prompt_text = prompt[0] if isinstance(prompt, (list, tuple)) else prompt

                logging.debug(f"Calibration step {i+1}: '{prompt_text[:50]}...'")

                try:
                    # Run full sampling pipeline
                    calib_pipeline(num_steps, prompt_text)
                except Exception as e:
                    logging.warning(f"Calibration step {i+1} failed: {e}")
                    # Continue with next prompt

        return forward_loop
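
    # How the returned forward_loop is typically consumed (sketch only; the actual
    # driver lives outside this module, and the config name below is an assumption
    # based on NVIDIA TensorRT Model Optimizer's public API):
    #
    #   import modelopt.torch.quantization as mtq
    #   mtq.quantize(diffusion_module, mtq.FP8_DEFAULT_CFG, forward_loop)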

    # Abstract methods for variants to implement
    def get_width(self) -> int:
        """Image width for calibration."""
        raise NotImplementedError

    def get_height(self) -> int:
        """Image height for calibration."""
        raise NotImplementedError

    def get_sampler_cfg(self) -> SamplerCFG:
        """Sampler configuration."""
        raise NotImplementedError

    def get_inference_steps(self) -> int:
        """Number of sampling steps per calibration iteration."""
        raise NotImplementedError


@register_recipe
class FluxDevRecipe(FluxRecipeBase):
    """FLUX Dev quantization recipe."""

    @classmethod
    def name(cls) -> str:
        return "flux_dev"

    def get_default_calib_steps(self) -> int:
        return 128

    def get_width(self) -> int:
        return 1024

    def get_height(self) -> int:
        return 1024

    def get_inference_steps(self) -> int:
        return 30

    def get_sampler_cfg(self) -> SamplerCFG:
        return SamplerCFG(
            cfg=3.5,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
            max_shift=1.15,
            base_shift=0.5
        )


@register_recipe
class FluxSchnellRecipe(FluxRecipeBase):
    """FLUX Schnell quantization recipe."""

    @classmethod
    def name(cls) -> str:
        return "flux_schnell"

    def get_default_calib_steps(self) -> int:
        return 64

    def get_width(self) -> int:
        return 1024

    def get_height(self) -> int:
        return 1024

    def get_inference_steps(self) -> int:
        return 4

    def get_sampler_cfg(self) -> SamplerCFG:
        return SamplerCFG(
            cfg=1.0,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
            max_shift=1.15,
            base_shift=0.5
        )