ComfyUI/comfy_api_nodes/nodes_inference_scaling.py
kasailab c254fa10d3 Add inference scaling nodes with quality monitoring during generation
- Implement VerifierSelectionNode for creating quality verifiers
- Implement InferenceScalingNode that wraps KSampler with quality checks
- Add VAE decoding during sampling steps for quality assessment
- Include quality verification logic using variance and edge detection
- Fix noise/latent handling to match ComfyUI patterns
- Add comprehensive error handling and logging
- Include documentation and debugging notes
2026-04-04 14:55:00 +09:00

440 lines
15 KiB
Python

"""
Inference Scaling Nodes for ComfyUI
This module provides nodes for adaptive inference scaling during image generation.
The system monitors image quality during the generation loop and adjusts sampling
parameters dynamically based on quality metrics.
"""
import logging
from inspect import cleandoc
from typing import Optional, Callable, Dict, Any

import numpy as np
import torch
from PIL import Image
from typing_extensions import override

import comfy.sample
from comfy import model_management
from comfy.samplers import KSampler
from comfy.sd import VAE
from comfy_api.latest import IO, ComfyExtension
# Get sampler and scheduler options
# Mirror KSampler's canonical lists so the combo inputs declared in the node
# schemas below stay in sync with whatever samplers/schedulers this ComfyUI
# build actually provides.
SAMPLER_OPTIONS = KSampler.SAMPLERS
SCHEDULER_OPTIONS = KSampler.SCHEDULERS
class QualityVerifier:
    """
    Abstract base for image quality verification.

    Concrete subclasses implement ``verify`` with an actual metric; this
    base only stores the acceptance threshold shared by all verifiers.
    """

    def __init__(self, threshold: float = 0.5):
        # Minimum score (in [0, 1]) for an image to count as acceptable.
        self.threshold = threshold

    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """
        Judge the quality of *image*.

        Args:
            image: Image tensor [B, H, W, C] with values in [0, 1].

        Returns:
            A ``(is_acceptable, quality_score)`` pair.
        """
        raise NotImplementedError("Subclasses must implement verify method")

    def get_quality_score(self, image: torch.Tensor) -> float:
        """
        Return only the quality score, ignoring the threshold decision.

        Args:
            image: Image tensor [B, H, W, C] with values in [0, 1].

        Returns:
            Quality score in the range [0, 1].
        """
        return self.verify(image)[1]
class SimpleVerifier(QualityVerifier):
    """
    Heuristic verifier built on basic image statistics.

    Quality is a weighted blend of overall pixel variance (detail) and
    mean gradient magnitude (structure/edges).
    """

    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """Score *image* and compare the score against ``self.threshold``."""
        # Collapse channels into a single luminance-style plane; for RGBA
        # input the alpha channel is ignored before averaging.
        channels = image[..., :3] if image.shape[-1] == 4 else image
        gray = channels.mean(dim=-1, keepdim=True)

        # Higher variance generally indicates more visible detail.
        variance = gray.var().item()

        # Gradient-based edge strength, computed on the CPU with numpy.
        plane = gray.squeeze().cpu().numpy()
        if plane.ndim == 3:
            # Batched input: evaluate only the first image.
            plane = plane[0]
        gx = np.gradient(plane, axis=1)
        gy = np.gradient(plane, axis=0)
        edge_strength = np.sqrt(gx**2 + gy**2).mean()

        # Normalize each metric to [0, 1] with fixed empirical scales,
        # then blend 60/40 in favor of variance.
        variance_score = min(variance * 10, 1.0)
        edge_score = min(edge_strength * 5, 1.0)
        quality_score = variance_score * 0.6 + edge_score * 0.4
        return quality_score >= self.threshold, quality_score
class VerifierSelectionNode(IO.ComfyNode):
    """
    Node for selecting or creating an image quality verifier.
    The verifier is used by InferenceScalingNode to judge image quality
    during the generation process.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VerifierSelectionNode",
            display_name="Verifier Selection",
            category="inference_scaling",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.Combo.Input(
                    "verifier_type",
                    options=["simple", "custom"],
                    default="simple",
                    tooltip="Type of quality verifier to use",
                ),
                IO.Float.Input(
                    "quality_threshold",
                    default=0.5,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Minimum quality score threshold (0.0-1.0)",
                ),
            ],
            outputs=[
                # The verifier is handed downstream as a registry-key string.
                IO.String.Output(),
            ],
        )

    @classmethod
    async def execute(
        cls,
        verifier_type: str = "simple",
        quality_threshold: float = 0.5,
    ) -> IO.NodeOutput:
        """
        Create (or refresh) a verifier and return its registry identifier.

        Args:
            verifier_type: Which verifier implementation to build. Only
                "simple" is implemented; "custom" intentionally falls back
                to the simple statistics-based verifier for now.
            quality_threshold: Acceptance threshold passed to the verifier.

        Returns:
            IO.NodeOutput wrapping the string id that InferenceScalingNode
            uses to look the verifier up.
        """
        # "custom" is not implemented yet; fall back to the simple verifier.
        verifier = SimpleVerifier(threshold=quality_threshold)

        # Module-level registry shared with InferenceScalingNode.
        registry = globals().setdefault('_verifier_registry', {})

        # Key by configuration instead of object identity: re-running the
        # node with unchanged settings reuses a single slot rather than
        # leaking a fresh entry per execution (the previous id()-based keys
        # grew without bound for the lifetime of the process).
        verifier_id = f"verifier_{verifier_type}_{quality_threshold:.4f}"
        registry[verifier_id] = verifier
        return IO.NodeOutput(verifier_id)
class InferenceScalingNode(IO.ComfyNode):
    """
    Main inference scaling node that wraps KSampler with quality monitoring.
    This node:
    1. Wraps the standard sampling process
    2. Periodically decodes latents to images using VAE during sampling
    3. Uses a verifier to judge image quality
    4. Logs quality metrics (mid-run CFG/steps adjustment is not yet applied)
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="InferenceScalingNode",
            display_name="Inference Scaling",
            category="inference_scaling",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.String.Input(
                    "verifier_id",
                    default="",
                    tooltip="Verifier identifier from VerifierSelectionNode",
                ),
                IO.Model.Input(
                    "model",
                    tooltip="Model to use for sampling",
                ),
                IO.Conditioning.Input(
                    "positive",
                    tooltip="Positive conditioning",
                ),
                IO.Conditioning.Input(
                    "negative",
                    tooltip="Negative conditioning",
                ),
                IO.Latent.Input(
                    "latent_image",
                    tooltip="Initial latent image",
                ),
                IO.Vae.Input(
                    "vae",
                    tooltip="VAE model for decoding latents during quality checks",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xffffffffffffffff,
                    control_after_generate=True,
                    tooltip="Random seed",
                ),
                IO.Int.Input(
                    "steps",
                    default=20,
                    min=1,
                    max=1000,
                    tooltip="Number of sampling steps",
                ),
                IO.Float.Input(
                    "cfg",
                    default=7.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    tooltip="CFG scale",
                ),
                IO.Combo.Input(
                    "sampler_name",
                    options=SAMPLER_OPTIONS,
                    default="euler",
                    tooltip="Sampler name",
                ),
                IO.Combo.Input(
                    "scheduler",
                    options=SCHEDULER_OPTIONS,
                    default="normal",
                    tooltip="Scheduler name",
                ),
                IO.Int.Input(
                    "check_interval",
                    default=5,
                    min=1,
                    max=50,
                    tooltip="Check quality every N steps",
                ),
                IO.Float.Input(
                    "quality_threshold",
                    default=0.5,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Quality threshold for early stopping or adjustment",
                ),
                IO.Float.Input(
                    "scale_factor",
                    default=1.2,
                    min=0.5,
                    max=2.0,
                    step=0.1,
                    tooltip="Factor to scale steps/CFG when quality is low",
                ),
            ],
            outputs=[
                IO.Latent.Output(),  # Resulting latent
            ],
        )

    @classmethod
    async def execute(
        cls,
        verifier_id: str,
        model: Any,
        positive: Any,
        negative: Any,
        latent_image: dict,
        vae: VAE,
        seed: int = 0,
        steps: int = 20,
        cfg: float = 7.0,
        sampler_name: str = "euler",
        scheduler: str = "normal",
        check_interval: int = 5,
        quality_threshold: float = 0.5,
        scale_factor: float = 1.2,
    ) -> IO.NodeOutput:
        """
        Run sampling while periodically decoding the latent and scoring it.

        Args:
            verifier_id: Registry key produced by VerifierSelectionNode.
            model, positive, negative: Standard KSampler inputs.
            latent_image: Latent dict with "samples" (and optionally
                "batch_index" / "noise_mask") as produced by upstream nodes.
            vae: VAE used to decode intermediate latents for scoring.
            seed, steps, cfg, sampler_name, scheduler: Sampling settings.
            check_interval: Run a quality check every N steps.
            quality_threshold: Score below which a warning is logged.
            scale_factor: Reserved for future CFG/steps rescaling; currently
                unused because parameters cannot change mid-run.

        Returns:
            IO.NodeOutput wrapping the resulting latent dict.

        Raises:
            ValueError: If verifier_id is not present in the registry.
            RuntimeError: If the underlying sampling call fails.
        """
        # Look up the verifier created by VerifierSelectionNode (shared
        # module-level registry).
        registry = globals().get('_verifier_registry', {})
        if verifier_id not in registry:
            raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
        verifier = registry[verifier_id]

        # Mirror common_ksampler's latent/noise preparation.
        latent_image_tensor = latent_image["samples"]
        latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
        batch_inds = latent_image.get("batch_index", None)
        noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
        noise_mask = latent_image.get("noise_mask", None)

        # (step, score, acceptable) tuples collected during sampling.
        quality_history = []

        def quality_check_callback(step, x0, x, total_steps):
            """
            Per-step callback invoked by comfy.sample.sample.

            comfy.sample.sample calls its callback positionally as
            ``callback(step, x0, x, total_steps)`` (the same contract as
            latent_preview.prepare_callback); x0 is the model's current
            denoised prediction in latent space. The previous dict-parameter
            signature raised a TypeError on the first step, aborting runs.
            """
            # Only check at the configured interval.
            if check_interval > 0 and step % check_interval != 0:
                return
            try:
                with torch.no_grad():
                    if not isinstance(x0, torch.Tensor):
                        return
                    # VAE.decode takes [B, C, H, W] latents and returns
                    # [B, H, W, C] images in [0, 1] — TODO confirm range
                    # against the VAE implementation in use.
                    decoded = vae.decode(x0)
                    if decoded is None:
                        return
                    # Score on the CPU to avoid holding extra GPU memory.
                    decoded_cpu = decoded.cpu()
                is_acceptable, quality_score = verifier.verify(decoded_cpu)
                quality_history.append((step, quality_score, is_acceptable))
                logging.info("Inference Scaling: Step %d, Quality: %.3f, Acceptable: %s",
                             step, quality_score, is_acceptable)
                # NOTE: CFG/steps cannot be changed mid-run from a callback;
                # these logs exist so quality_threshold and scale_factor can
                # be tuned offline.
                if not is_acceptable and quality_score < quality_threshold:
                    logging.warning("Inference Scaling: Low quality detected at step %d: %.3f < %s",
                                    step, quality_score, quality_threshold)
                elif is_acceptable and quality_score > quality_threshold * 1.2:
                    logging.info("Inference Scaling: Good quality at step %d: %.3f",
                                 step, quality_score)
            except Exception as e:
                # A failed quality check must never abort the generation.
                logging.warning("Inference Scaling: Quality check failed at step %s: %s", step, e)

        try:
            # comfy.sample.sample handles device placement, masking and the
            # callback plumbing.
            result_samples = comfy.sample.sample(
                model=model,
                noise=noise,
                steps=steps,
                cfg=cfg,
                sampler_name=sampler_name,
                scheduler=scheduler,
                positive=positive,
                negative=negative,
                latent_image=latent_image_tensor,
                denoise=1.0,
                disable_noise=False,
                start_step=None,
                last_step=None,
                force_full_denoise=False,
                noise_mask=noise_mask,
                callback=quality_check_callback,
                disable_pbar=False,
                seed=seed,
            )
        except Exception as e:
            # Chain the original exception instead of flattening its
            # traceback into the message string.
            raise RuntimeError(f"Inference scaling sampling failed: {e}") from e

        # Preserve extra keys (e.g. noise_mask) from the incoming latent dict.
        result_latent = latent_image.copy()
        result_latent["samples"] = result_samples
        return IO.NodeOutput(result_latent)
class InferenceScalingExtension(ComfyExtension):
    """Registers the inference scaling node pack with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        # Order only affects how the nodes are listed; they are independent.
        nodes: list[type[IO.ComfyNode]] = [
            VerifierSelectionNode,
            InferenceScalingNode,
        ]
        return nodes
async def comfy_entrypoint() -> InferenceScalingExtension:
    """Entry point ComfyUI calls to load this extension."""
    extension = InferenceScalingExtension()
    return extension