mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 14:52:46 +08:00
fix: resolve ruff linting errors in nodes_inference_scaling
- Remove unused imports: Optional, Callable, Dict, PIL.Image, traceback - Remove trailing whitespace on blank lines (W293) - Remove unused local variable: device
This commit is contained in:
parent
74442533da
commit
c5ce3e1375
@ -6,12 +6,11 @@ The system monitors image quality during the generation loop and adjusts samplin
|
||||
parameters dynamically based on quality metrics.
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
from typing import Any
|
||||
from typing_extensions import override
|
||||
from inspect import cleandoc
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy.samplers import KSampler
|
||||
@ -29,29 +28,29 @@ class QualityVerifier:
|
||||
Base class for image quality verification.
|
||||
Subclasses should implement specific quality metrics.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, threshold: float = 0.5):
|
||||
self.threshold = threshold
|
||||
|
||||
|
||||
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||
"""
|
||||
Verify image quality.
|
||||
|
||||
|
||||
Args:
|
||||
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (is_acceptable, quality_score)
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement verify method")
|
||||
|
||||
|
||||
def get_quality_score(self, image: torch.Tensor) -> float:
|
||||
"""
|
||||
Get quality score without threshold check.
|
||||
|
||||
|
||||
Args:
|
||||
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||
|
||||
|
||||
Returns:
|
||||
Quality score in range [0, 1]
|
||||
"""
|
||||
@ -64,7 +63,7 @@ class SimpleVerifier(QualityVerifier):
|
||||
Simple verifier using basic image statistics.
|
||||
Uses variance and edge detection as quality indicators.
|
||||
"""
|
||||
|
||||
|
||||
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||
"""
|
||||
Simple quality check based on image variance and structure.
|
||||
@ -74,29 +73,29 @@ class SimpleVerifier(QualityVerifier):
|
||||
gray = image[..., :3].mean(dim=-1, keepdim=True)
|
||||
else:
|
||||
gray = image.mean(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
# Calculate variance (higher variance = more detail)
|
||||
variance = gray.var().item()
|
||||
|
||||
|
||||
# Simple edge detection using Sobel-like operator
|
||||
# Convert to numpy for easier processing
|
||||
img_np = gray.squeeze().cpu().numpy()
|
||||
if len(img_np.shape) == 3:
|
||||
img_np = img_np[0] # Take first batch
|
||||
|
||||
|
||||
# Calculate gradients
|
||||
grad_x = np.gradient(img_np, axis=1)
|
||||
grad_y = np.gradient(img_np, axis=0)
|
||||
edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
|
||||
|
||||
|
||||
# Combine metrics (normalize to [0, 1])
|
||||
variance_score = min(variance * 10, 1.0) # Scale variance
|
||||
edge_score = min(edge_strength * 5, 1.0) # Scale edge strength
|
||||
|
||||
|
||||
quality_score = (variance_score * 0.6 + edge_score * 0.4)
|
||||
|
||||
|
||||
is_acceptable = quality_score >= self.threshold
|
||||
|
||||
|
||||
return is_acceptable, quality_score
|
||||
|
||||
|
||||
@ -143,7 +142,7 @@ class VerifierSelectionNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
"""
|
||||
Create and return a verifier identifier.
|
||||
|
||||
|
||||
In a real implementation, this would create and store the verifier
|
||||
instance. For now, we return a serialized identifier.
|
||||
"""
|
||||
@ -152,23 +151,23 @@ class VerifierSelectionNode(IO.ComfyNode):
|
||||
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||
else:
|
||||
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||
|
||||
|
||||
# Store verifier in a way that InferenceScalingNode can access it
|
||||
# For now, we'll use a simple approach: store in a module-level dict
|
||||
import comfy_api_nodes.nodes_inference_scaling as mod
|
||||
if not hasattr(mod, '_verifier_registry'):
|
||||
mod._verifier_registry = {}
|
||||
|
||||
|
||||
verifier_id = f"verifier_{id(verifier)}"
|
||||
mod._verifier_registry[verifier_id] = verifier
|
||||
|
||||
|
||||
return IO.NodeOutput(verifier_id)
|
||||
|
||||
|
||||
class InferenceScalingNode(IO.ComfyNode):
|
||||
"""
|
||||
Main inference scaling node that wraps KSampler with quality monitoring.
|
||||
|
||||
|
||||
This node:
|
||||
1. Wraps the standard sampling process
|
||||
2. Periodically decodes latents to images using VAE during sampling
|
||||
@ -293,53 +292,53 @@ class InferenceScalingNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
"""
|
||||
Execute inference scaling sampling.
|
||||
|
||||
|
||||
This wraps the standard sampling process and adds quality monitoring.
|
||||
"""
|
||||
# Get verifier from registry
|
||||
import comfy_api_nodes.nodes_inference_scaling as mod
|
||||
if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
|
||||
raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
|
||||
|
||||
|
||||
verifier = mod._verifier_registry[verifier_id]
|
||||
|
||||
|
||||
# Prepare sampling parameters
|
||||
device = model_management.get_torch_device()
|
||||
|
||||
model_management.get_torch_device()
|
||||
|
||||
# Extract latent_image tensor from dict
|
||||
latent_image_tensor = latent_image["samples"]
|
||||
|
||||
|
||||
# Fix empty latent channels if needed
|
||||
latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
|
||||
|
||||
|
||||
# Prepare noise from latent_image and seed (same as common_ksampler does)
|
||||
batch_inds = latent_image.get("batch_index", None)
|
||||
noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
|
||||
|
||||
|
||||
# Get noise_mask if present
|
||||
noise_mask = latent_image.get("noise_mask", None)
|
||||
|
||||
|
||||
# Track quality during sampling
|
||||
quality_history = []
|
||||
current_cfg = cfg
|
||||
|
||||
|
||||
def quality_check_callback(callback_dict: dict):
|
||||
"""
|
||||
Callback function called during sampling steps.
|
||||
Decodes latents, checks quality, and adjusts parameters.
|
||||
|
||||
|
||||
Args:
|
||||
callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised'
|
||||
"""
|
||||
nonlocal current_cfg, quality_history
|
||||
|
||||
|
||||
step = callback_dict['i']
|
||||
denoised = callback_dict['denoised'] # This is the predicted x0
|
||||
|
||||
|
||||
# Only check at specified intervals
|
||||
if check_interval > 0 and step % check_interval != 0:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# Decode latent to image using VAE
|
||||
# denoised is the predicted x0 in latent space [B, C, H, W]
|
||||
@ -347,26 +346,26 @@ class InferenceScalingNode(IO.ComfyNode):
|
||||
# Ensure denoised is a proper tensor
|
||||
if not isinstance(denoised, torch.Tensor):
|
||||
return
|
||||
|
||||
|
||||
# VAE.decode expects [B, C, H, W] format and handles device management
|
||||
# It returns [B, H, W, C] in range [0, 1]
|
||||
decoded = vae.decode(denoised)
|
||||
|
||||
|
||||
# Ensure decoded is in correct format [B, H, W, C]
|
||||
if decoded is None:
|
||||
return
|
||||
|
||||
|
||||
# Move to CPU for quality checking to avoid GPU memory issues
|
||||
decoded_cpu = decoded.cpu()
|
||||
|
||||
|
||||
# Check quality
|
||||
is_acceptable, quality_score = verifier.verify(decoded_cpu)
|
||||
quality_history.append((step, quality_score, is_acceptable))
|
||||
|
||||
|
||||
# Log quality for debugging
|
||||
import logging
|
||||
logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")
|
||||
|
||||
|
||||
# Note: CFG adjustment here won't affect current sampling run
|
||||
# as CFG is set at the start. This is for future reference/logging.
|
||||
if not is_acceptable and quality_score < quality_threshold:
|
||||
@ -375,15 +374,14 @@ class InferenceScalingNode(IO.ComfyNode):
|
||||
elif is_acceptable and quality_score > quality_threshold * 1.2:
|
||||
# Quality is good
|
||||
logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# If decoding fails, continue sampling (don't break the generation)
|
||||
import logging
|
||||
import traceback
|
||||
logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")
|
||||
# Don't log full traceback in production to avoid spam
|
||||
# logging.debug(traceback.format_exc())
|
||||
|
||||
|
||||
# Perform sampling with quality monitoring
|
||||
try:
|
||||
# Use comfy.sample.sample which handles everything properly
|
||||
@ -407,13 +405,13 @@ class InferenceScalingNode(IO.ComfyNode):
|
||||
disable_pbar=False,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
# Return the result as a latent dict (preserve other keys from input)
|
||||
result_latent = latent_image.copy()
|
||||
result_latent["samples"] = result_samples
|
||||
|
||||
|
||||
return IO.NodeOutput(result_latent)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}")
|
||||
@ -423,7 +421,7 @@ class InferenceScalingExtension(ComfyExtension):
|
||||
"""
|
||||
Extension class that registers inference scaling nodes.
|
||||
"""
|
||||
|
||||
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user