diff --git a/comfy_api_nodes/nodes_inference_scaling.py b/comfy_api_nodes/nodes_inference_scaling.py index c0acc531c..fb67fc089 100644 --- a/comfy_api_nodes/nodes_inference_scaling.py +++ b/comfy_api_nodes/nodes_inference_scaling.py @@ -6,12 +6,11 @@ The system monitors image quality during the generation loop and adjusts samplin parameters dynamically based on quality metrics. """ -from typing import Optional, Callable, Dict, Any +from typing import Any from typing_extensions import override from inspect import cleandoc import torch import numpy as np -from PIL import Image from comfy_api.latest import IO, ComfyExtension from comfy.samplers import KSampler @@ -29,29 +28,29 @@ class QualityVerifier: Base class for image quality verification. Subclasses should implement specific quality metrics. """ - + def __init__(self, threshold: float = 0.5): self.threshold = threshold - + def verify(self, image: torch.Tensor) -> tuple[bool, float]: """ Verify image quality. - + Args: image: Image tensor [B, H, W, C] in range [0, 1] - + Returns: Tuple of (is_acceptable, quality_score) """ raise NotImplementedError("Subclasses must implement verify method") - + def get_quality_score(self, image: torch.Tensor) -> float: """ Get quality score without threshold check. - + Args: image: Image tensor [B, H, W, C] in range [0, 1] - + Returns: Quality score in range [0, 1] """ @@ -64,7 +63,7 @@ class SimpleVerifier(QualityVerifier): Simple verifier using basic image statistics. Uses variance and edge detection as quality indicators. """ - + def verify(self, image: torch.Tensor) -> tuple[bool, float]: """ Simple quality check based on image variance and structure. @@ -74,29 +73,29 @@ class SimpleVerifier(QualityVerifier): gray = image[..., :3].mean(dim=-1, keepdim=True) else: gray = image.mean(dim=-1, keepdim=True) - + # Calculate variance (higher variance = more detail) variance = gray.var().item() - + # Simple edge detection using Sobel-like operator # Convert to numpy for easier processing img_np = gray.squeeze().cpu().numpy() if len(img_np.shape) == 3: img_np = img_np[0] # Take first batch - + # Calculate gradients grad_x = np.gradient(img_np, axis=1) grad_y = np.gradient(img_np, axis=0) edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean() - + # Combine metrics (normalize to [0, 1]) variance_score = min(variance * 10, 1.0) # Scale variance edge_score = min(edge_strength * 5, 1.0) # Scale edge strength - + quality_score = (variance_score * 0.6 + edge_score * 0.4) - + is_acceptable = quality_score >= self.threshold - + return is_acceptable, quality_score @@ -143,7 +142,7 @@ class VerifierSelectionNode(IO.ComfyNode): ) -> IO.NodeOutput: """ Create and return a verifier identifier. - + In a real implementation, this would create and store the verifier instance. For now, we return a serialized identifier. """ @@ -152,23 +151,23 @@ class VerifierSelectionNode(IO.ComfyNode): verifier = SimpleVerifier(threshold=quality_threshold) else: verifier = SimpleVerifier(threshold=quality_threshold) - + # Store verifier in a way that InferenceScalingNode can access it # For now, we'll use a simple approach: store in a module-level dict import comfy_api_nodes.nodes_inference_scaling as mod if not hasattr(mod, '_verifier_registry'): mod._verifier_registry = {} - + verifier_id = f"verifier_{id(verifier)}" mod._verifier_registry[verifier_id] = verifier - + return IO.NodeOutput(verifier_id) class InferenceScalingNode(IO.ComfyNode): """ Main inference scaling node that wraps KSampler with quality monitoring. - + This node: 1. Wraps the standard sampling process 2. Periodically decodes latents to images using VAE during sampling @@ -293,53 +292,53 @@ class InferenceScalingNode(IO.ComfyNode): ) -> IO.NodeOutput: """ Execute inference scaling sampling. - + This wraps the standard sampling process and adds quality monitoring. """ # Get verifier from registry import comfy_api_nodes.nodes_inference_scaling as mod if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry: raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.") - + verifier = mod._verifier_registry[verifier_id] - + # Prepare sampling parameters - device = model_management.get_torch_device() - + model_management.get_torch_device() + # Extract latent_image tensor from dict latent_image_tensor = latent_image["samples"] - + # Fix empty latent channels if needed latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor) - + # Prepare noise from latent_image and seed (same as common_ksampler does) batch_inds = latent_image.get("batch_index", None) noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds) - + # Get noise_mask if present noise_mask = latent_image.get("noise_mask", None) - + # Track quality during sampling quality_history = [] current_cfg = cfg - + def quality_check_callback(callback_dict: dict): """ Callback function called during sampling steps. Decodes latents, checks quality, and adjusts parameters. - + Args: callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised' """ nonlocal current_cfg, quality_history - + step = callback_dict['i'] denoised = callback_dict['denoised'] # This is the predicted x0 - + # Only check at specified intervals if check_interval > 0 and step % check_interval != 0: return - + try: # Decode latent to image using VAE # denoised is the predicted x0 in latent space [B, C, H, W] @@ -347,26 +346,26 @@ class InferenceScalingNode(IO.ComfyNode): # Ensure denoised is a proper tensor if not isinstance(denoised, torch.Tensor): return - + # VAE.decode expects [B, C, H, W] format and handles device management # It returns [B, H, W, C] in range [0, 1] decoded = vae.decode(denoised) - + # Ensure decoded is in correct format [B, H, W, C] if decoded is None: return - + # Move to CPU for quality checking to avoid GPU memory issues decoded_cpu = decoded.cpu() - + # Check quality is_acceptable, quality_score = verifier.verify(decoded_cpu) quality_history.append((step, quality_score, is_acceptable)) - + # Log quality for debugging import logging logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}") - + # Note: CFG adjustment here won't affect current sampling run # as CFG is set at the start. This is for future reference/logging. if not is_acceptable and quality_score < quality_threshold: @@ -375,15 +374,14 @@ class InferenceScalingNode(IO.ComfyNode): elif is_acceptable and quality_score > quality_threshold * 1.2: # Quality is good logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}") - + except Exception as e: # If decoding fails, continue sampling (don't break the generation) import logging - import traceback logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}") # Don't log full traceback in production to avoid spam # logging.debug(traceback.format_exc()) - + # Perform sampling with quality monitoring try: # Use comfy.sample.sample which handles everything properly @@ -407,13 +405,13 @@ class InferenceScalingNode(IO.ComfyNode): disable_pbar=False, seed=seed, ) - + # Return the result as a latent dict (preserve other keys from input) result_latent = latent_image.copy() result_latent["samples"] = result_samples - + return IO.NodeOutput(result_latent) - + except Exception as e: import traceback raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}") @@ -423,7 +421,7 @@ class InferenceScalingExtension(ComfyExtension): """ Extension class that registers inference scaling nodes. """ - + @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [