mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 23:02:49 +08:00
- Implement VerifierSelectionNode for creating quality verifiers
- Implement InferenceScalingNode that wraps KSampler with quality checks
- Add VAE decoding during sampling steps for quality assessment
- Include quality verification logic using variance and edge detection
- Fix noise/latent handling to match ComfyUI patterns
- Add comprehensive error handling and logging
- Include documentation and debugging notes
440 lines
15 KiB
Python
440 lines
15 KiB
Python
"""
|
|
Inference Scaling Nodes for ComfyUI
|
|
|
|
This module provides nodes for adaptive inference scaling during image generation.
|
|
The system monitors image quality during the generation loop and adjusts sampling
|
|
parameters dynamically based on quality metrics.
|
|
"""
|
|
|
|
from typing import Optional, Callable, Dict, Any
|
|
from typing_extensions import override
|
|
from inspect import cleandoc
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from comfy_api.latest import IO, ComfyExtension
|
|
from comfy.samplers import KSampler
|
|
from comfy.sd import VAE
|
|
from comfy import model_management
|
|
import comfy.sample
|
|
|
|
# Get sampler and scheduler options
|
|
SAMPLER_OPTIONS = KSampler.SAMPLERS
|
|
SCHEDULER_OPTIONS = KSampler.SCHEDULERS
|
|
|
|
|
|
class QualityVerifier:
    """
    Abstract base for image quality verification.

    Concrete subclasses implement ``verify`` with a specific quality metric;
    ``threshold`` decides when a score counts as acceptable.
    """

    def __init__(self, threshold: float = 0.5):
        # Minimum score (in [0, 1]) an image must reach to be acceptable.
        self.threshold = threshold

    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """
        Judge the quality of *image*.

        Args:
            image: Image tensor [B, H, W, C] with values in [0, 1].

        Returns:
            Tuple of (is_acceptable, quality_score).

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise NotImplementedError("Subclasses must implement verify method")

    def get_quality_score(self, image: torch.Tensor) -> float:
        """
        Return the quality score for *image*, ignoring the threshold.

        Args:
            image: Image tensor [B, H, W, C] with values in [0, 1].

        Returns:
            Quality score in range [0, 1].
        """
        return self.verify(image)[1]
|
|
|
|
|
|
class SimpleVerifier(QualityVerifier):
    """
    Heuristic verifier built from basic image statistics.

    Blends pixel variance (a proxy for detail) with mean gradient magnitude
    (a proxy for edge structure) into a single score in [0, 1].
    """

    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """
        Score *image* using variance and edge strength heuristics.
        """
        # Collapse channels into a single luminance-like plane; RGBA input
        # drops the alpha channel before averaging.
        rgb = image[..., :3] if image.shape[-1] == 4 else image
        gray = rgb.mean(dim=-1, keepdim=True)

        # Higher variance generally indicates more visible detail.
        variance = gray.var().item()

        # Rough edge detection via finite-difference gradients in numpy.
        plane = gray.squeeze().cpu().numpy()
        if plane.ndim == 3:
            # Batched input: restrict edge analysis to the first image.
            plane = plane[0]

        # np.gradient returns one array per axis: (axis 0 = y, axis 1 = x).
        gy = np.gradient(plane, axis=0)
        gx = np.gradient(plane, axis=1)
        edge_strength = np.sqrt(gx**2 + gy**2).mean()

        # Normalize both metrics into [0, 1] with fixed scale factors.
        variance_score = min(variance * 10, 1.0)
        edge_score = min(edge_strength * 5, 1.0)

        # Weighted blend: detail counts slightly more than edge structure.
        quality_score = variance_score * 0.6 + edge_score * 0.4

        return quality_score >= self.threshold, quality_score
|
|
|
|
|
|
class VerifierSelectionNode(IO.ComfyNode):
    """
    Node for selecting or creating an image quality verifier.

    The verifier is used by InferenceScalingNode to judge image quality
    during the generation process.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VerifierSelectionNode",
            display_name="Verifier Selection",
            category="inference_scaling",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.Combo.Input(
                    "verifier_type",
                    options=["simple", "custom"],
                    default="simple",
                    tooltip="Type of quality verifier to use",
                ),
                IO.Float.Input(
                    "quality_threshold",
                    default=0.5,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Minimum quality score threshold (0.0-1.0)",
                ),
            ],
            outputs=[
                IO.String.Output(),  # Verifier identifier (registry key string)
            ],
        )

    @classmethod
    async def execute(
        cls,
        verifier_type: str = "simple",
        quality_threshold: float = 0.5,
    ) -> IO.NodeOutput:
        """
        Create a verifier and return its registry identifier.

        The verifier instance is stored in a module-level registry so that
        InferenceScalingNode can later resolve it from the returned string.

        Args:
            verifier_type: Kind of verifier to build. Only "simple" is
                implemented; any other value falls back to SimpleVerifier
                with a warning.
            quality_threshold: Minimum acceptable quality score in [0, 1].

        Returns:
            IO.NodeOutput wrapping the verifier id string.
        """
        import logging

        if verifier_type != "simple":
            # "custom" is offered by the schema but has no implementation yet;
            # warn instead of silently pretending it exists, and fall back so
            # the workflow still runs.
            logging.warning(
                "VerifierSelectionNode: verifier_type %r is not implemented, "
                "falling back to SimpleVerifier",
                verifier_type,
            )
        verifier = SimpleVerifier(threshold=quality_threshold)

        # Store the verifier in a module-level registry keyed by object id so
        # InferenceScalingNode (which only receives the string id) can find it.
        import comfy_api_nodes.nodes_inference_scaling as mod
        if not hasattr(mod, '_verifier_registry'):
            mod._verifier_registry = {}

        verifier_id = f"verifier_{id(verifier)}"
        mod._verifier_registry[verifier_id] = verifier

        return IO.NodeOutput(verifier_id)
|
|
|
|
|
|
class InferenceScalingNode(IO.ComfyNode):
    """
    Main inference scaling node that wraps KSampler with quality monitoring.

    This node:
    1. Wraps the standard sampling process
    2. Periodically decodes latents to images using VAE during sampling
    3. Uses a verifier to judge image quality
    4. Logs quality metrics for analysis (CFG/steps cannot change mid-run)
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="InferenceScalingNode",
            display_name="Inference Scaling",
            category="inference_scaling",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.String.Input(
                    "verifier_id",
                    default="",
                    tooltip="Verifier identifier from VerifierSelectionNode",
                ),
                IO.Model.Input(
                    "model",
                    tooltip="Model to use for sampling",
                ),
                IO.Conditioning.Input(
                    "positive",
                    tooltip="Positive conditioning",
                ),
                IO.Conditioning.Input(
                    "negative",
                    tooltip="Negative conditioning",
                ),
                IO.Latent.Input(
                    "latent_image",
                    tooltip="Initial latent image",
                ),
                IO.Vae.Input(
                    "vae",
                    tooltip="VAE model for decoding latents during quality checks",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xffffffffffffffff,
                    control_after_generate=True,
                    tooltip="Random seed",
                ),
                IO.Int.Input(
                    "steps",
                    default=20,
                    min=1,
                    max=1000,
                    tooltip="Number of sampling steps",
                ),
                IO.Float.Input(
                    "cfg",
                    default=7.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    tooltip="CFG scale",
                ),
                IO.Combo.Input(
                    "sampler_name",
                    options=SAMPLER_OPTIONS,
                    default="euler",
                    tooltip="Sampler name",
                ),
                IO.Combo.Input(
                    "scheduler",
                    options=SCHEDULER_OPTIONS,
                    default="normal",
                    tooltip="Scheduler name",
                ),
                IO.Int.Input(
                    "check_interval",
                    default=5,
                    min=1,
                    max=50,
                    tooltip="Check quality every N steps",
                ),
                IO.Float.Input(
                    "quality_threshold",
                    default=0.5,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Quality threshold for early stopping or adjustment",
                ),
                IO.Float.Input(
                    "scale_factor",
                    default=1.2,
                    min=0.5,
                    max=2.0,
                    step=0.1,
                    tooltip="Factor to scale steps/CFG when quality is low",
                ),
            ],
            outputs=[
                IO.Latent.Output(),  # Sampled latent
            ],
        )

    @classmethod
    async def execute(
        cls,
        verifier_id: str,
        model: Any,
        positive: Any,
        negative: Any,
        latent_image: dict,
        vae: VAE,
        seed: int = 0,
        steps: int = 20,
        cfg: float = 7.0,
        sampler_name: str = "euler",
        scheduler: str = "normal",
        check_interval: int = 5,
        quality_threshold: float = 0.5,
        scale_factor: float = 1.2,
    ) -> IO.NodeOutput:
        """
        Run sampling with periodic quality monitoring.

        Wraps comfy.sample.sample and installs a step callback that decodes
        the current denoised latent with the VAE, scores it with the selected
        verifier, and logs the verdict. A failed or low-quality check never
        interrupts the generation itself.

        Args:
            verifier_id: Registry key produced by VerifierSelectionNode.
            model: Diffusion model to sample with.
            positive: Positive conditioning.
            negative: Negative conditioning.
            latent_image: Latent dict with a "samples" tensor; optional
                "batch_index" and "noise_mask" keys are honored.
            vae: VAE used to decode latents for the quality checks.
            seed: Noise seed.
            steps: Number of sampling steps.
            cfg: CFG scale (fixed for the whole run).
            sampler_name: Sampler to use.
            scheduler: Scheduler to use.
            check_interval: Run a quality check every N steps.
            quality_threshold: Threshold used when logging quality verdicts.
            scale_factor: Reserved for future step/CFG scaling; currently
                unused (kept for schema compatibility).

        Returns:
            IO.NodeOutput wrapping the resulting latent dict (extra keys from
            the input latent dict are preserved).

        Raises:
            ValueError: If verifier_id is not present in the registry.
            RuntimeError: If sampling itself fails.
        """
        import logging

        # Resolve the verifier stored by VerifierSelectionNode.
        import comfy_api_nodes.nodes_inference_scaling as mod
        registry = getattr(mod, '_verifier_registry', {})
        if verifier_id not in registry:
            raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
        verifier = registry[verifier_id]

        # Extract the latent tensor and normalize its channel count.
        latent_samples = latent_image["samples"]
        latent_samples = comfy.sample.fix_empty_latent_channels(model, latent_samples)

        # Prepare noise from the latent and seed (same as common_ksampler does).
        batch_inds = latent_image.get("batch_index", None)
        noise = comfy.sample.prepare_noise(latent_samples, seed, batch_inds)
        noise_mask = latent_image.get("noise_mask", None)

        # (step, score, acceptable) tuples collected during sampling.
        quality_history = []

        def quality_check_callback(step, x0, x, total_steps):
            """
            Per-step sampling callback.

            comfy.sample.sample invokes the callback positionally as
            callback(step, x0, x, total_steps); the original dict-argument
            signature would have raised TypeError on every checked step.
            x0 is the current denoised (predicted clean) latent.
            """
            # Only check at the configured interval.
            if check_interval > 0 and step % check_interval != 0:
                return

            try:
                with torch.no_grad():
                    # Defensive: skip if the prediction is not a tensor.
                    if not isinstance(x0, torch.Tensor):
                        return

                    # VAE.decode expects [B, C, H, W] latents, handles device
                    # management, and returns [B, H, W, C] images in [0, 1].
                    decoded = vae.decode(x0)
                    if decoded is None:
                        return

                    # Score on CPU to avoid extra GPU memory pressure.
                    is_acceptable, quality_score = verifier.verify(decoded.cpu())
                    quality_history.append((step, quality_score, is_acceptable))

                    logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")

                    # CFG is fixed once sampling starts, so quality verdicts
                    # are only logged here for analysis / future tuning.
                    if not is_acceptable and quality_score < quality_threshold:
                        logging.warning(f"Inference Scaling: Low quality detected at step {step}: {quality_score:.3f} < {quality_threshold}")
                    elif is_acceptable and quality_score > quality_threshold * 1.2:
                        logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
            except Exception as e:
                # A failed quality check must never abort the generation.
                logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")

        try:
            result_samples = comfy.sample.sample(
                model=model,
                noise=noise,
                steps=steps,
                cfg=cfg,
                sampler_name=sampler_name,
                scheduler=scheduler,
                positive=positive,
                negative=negative,
                latent_image=latent_samples,
                denoise=1.0,
                disable_noise=False,
                start_step=None,
                last_step=None,
                force_full_denoise=False,
                noise_mask=noise_mask,
                callback=quality_check_callback,
                disable_pbar=False,
                seed=seed,
            )
        except Exception as e:
            # Chain the original exception instead of flattening it into text.
            raise RuntimeError(f"Inference scaling sampling failed: {e}") from e

        # Return the result as a latent dict, preserving other input keys.
        result_latent = latent_image.copy()
        result_latent["samples"] = result_samples
        return IO.NodeOutput(result_latent)
|
|
|
|
|
|
class InferenceScalingExtension(ComfyExtension):
    """
    ComfyUI extension that registers the inference scaling nodes.
    """

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        # Both nodes are exposed: the verifier factory and the sampler wrapper.
        nodes = [
            VerifierSelectionNode,
            InferenceScalingNode,
        ]
        return nodes
|
|
|
|
|
|
async def comfy_entrypoint() -> InferenceScalingExtension:
    """
    Entry point that ComfyUI calls to load this extension.

    Returns:
        A fresh InferenceScalingExtension instance.
    """
    return InferenceScalingExtension()
|