fix: resolve ruff linting errors in nodes_inference_scaling

- Remove unused imports: Optional, Callable, Dict, PIL.Image, traceback (F401)
- Remove trailing whitespace on blank lines (W293)
- Remove unused local variable: device (F841)
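
For reference, a hedged sketch of how these warnings can be reproduced and auto-fixed with ruff's CLI; the exact invocation and rule selection depend on the repository's ruff configuration (e.g. pyproject.toml), so treat the flags below as illustrative rather than the project's canonical lint command:

    # Check only the rules touched by this commit:
    #   F401 (unused import), F841 (unused local variable), W293 (whitespace on blank line)
    ruff check --select F401,F841,W293 comfy_api_nodes/nodes_inference_scaling.py

    # Apply the available automatic fixes; the unused local (F841) may still need
    # a manual edit or ruff's unsafe fixes.
    ruff check --select F401,F841,W293 --fix comfy_api_nodes/nodes_inference_scaling.py
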
Author: unknown
Date:   2026-04-18 15:31:56 +09:00
Parent: 74442533da
Commit: c5ce3e1375


@@ -6,12 +6,11 @@ The system monitors image quality during the generation loop and adjusts sampling
 parameters dynamically based on quality metrics.
 """
 
-from typing import Optional, Callable, Dict, Any
+from typing import Any
 from typing_extensions import override
 from inspect import cleandoc
 import torch
 import numpy as np
-from PIL import Image
 
 from comfy_api.latest import IO, ComfyExtension
 from comfy.samplers import KSampler
@@ -29,29 +28,29 @@ class QualityVerifier:
     Base class for image quality verification.
     Subclasses should implement specific quality metrics.
     """
-    
+
     def __init__(self, threshold: float = 0.5):
         self.threshold = threshold
-    
+
     def verify(self, image: torch.Tensor) -> tuple[bool, float]:
         """
         Verify image quality.
-        
+
         Args:
             image: Image tensor [B, H, W, C] in range [0, 1]
-        
+
         Returns:
             Tuple of (is_acceptable, quality_score)
         """
         raise NotImplementedError("Subclasses must implement verify method")
-    
+
     def get_quality_score(self, image: torch.Tensor) -> float:
         """
         Get quality score without threshold check.
-        
+
         Args:
             image: Image tensor [B, H, W, C] in range [0, 1]
-        
+
         Returns:
             Quality score in range [0, 1]
         """
@@ -64,7 +63,7 @@ class SimpleVerifier(QualityVerifier):
     Simple verifier using basic image statistics.
     Uses variance and edge detection as quality indicators.
     """
-    
+
     def verify(self, image: torch.Tensor) -> tuple[bool, float]:
         """
         Simple quality check based on image variance and structure.
@@ -74,29 +73,29 @@ class SimpleVerifier(QualityVerifier):
             gray = image[..., :3].mean(dim=-1, keepdim=True)
         else:
             gray = image.mean(dim=-1, keepdim=True)
-        
+
         # Calculate variance (higher variance = more detail)
         variance = gray.var().item()
-        
+
         # Simple edge detection using Sobel-like operator
         # Convert to numpy for easier processing
         img_np = gray.squeeze().cpu().numpy()
         if len(img_np.shape) == 3:
             img_np = img_np[0]  # Take first batch
-        
+
         # Calculate gradients
         grad_x = np.gradient(img_np, axis=1)
         grad_y = np.gradient(img_np, axis=0)
         edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
-        
+
         # Combine metrics (normalize to [0, 1])
         variance_score = min(variance * 10, 1.0)  # Scale variance
         edge_score = min(edge_strength * 5, 1.0)  # Scale edge strength
         quality_score = (variance_score * 0.6 + edge_score * 0.4)
-        
+
         is_acceptable = quality_score >= self.threshold
         return is_acceptable, quality_score
@@ -143,7 +142,7 @@ class VerifierSelectionNode(IO.ComfyNode):
     ) -> IO.NodeOutput:
         """
        Create and return a verifier identifier.
-        
+
        In a real implementation, this would create and store the verifier
        instance. For now, we return a serialized identifier.
        """
@@ -152,23 +151,23 @@
             verifier = SimpleVerifier(threshold=quality_threshold)
         else:
             verifier = SimpleVerifier(threshold=quality_threshold)
-        
+
         # Store verifier in a way that InferenceScalingNode can access it
         # For now, we'll use a simple approach: store in a module-level dict
         import comfy_api_nodes.nodes_inference_scaling as mod
         if not hasattr(mod, '_verifier_registry'):
             mod._verifier_registry = {}
-        
+
         verifier_id = f"verifier_{id(verifier)}"
         mod._verifier_registry[verifier_id] = verifier
-        
+
         return IO.NodeOutput(verifier_id)
 
 
 class InferenceScalingNode(IO.ComfyNode):
     """
     Main inference scaling node that wraps KSampler with quality monitoring.
-    
+
     This node:
     1. Wraps the standard sampling process
     2. Periodically decodes latents to images using VAE during sampling
@@ -293,53 +292,53 @@ class InferenceScalingNode(IO.ComfyNode):
     ) -> IO.NodeOutput:
         """
         Execute inference scaling sampling.
-        
+
         This wraps the standard sampling process and adds quality monitoring.
         """
-        
+
         # Get verifier from registry
         import comfy_api_nodes.nodes_inference_scaling as mod
         if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
             raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
         verifier = mod._verifier_registry[verifier_id]
-        
+
         # Prepare sampling parameters
-        device = model_management.get_torch_device()
+        model_management.get_torch_device()
 
         # Extract latent_image tensor from dict
         latent_image_tensor = latent_image["samples"]
-        
+
         # Fix empty latent channels if needed
         latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
-        
+
         # Prepare noise from latent_image and seed (same as common_ksampler does)
         batch_inds = latent_image.get("batch_index", None)
         noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
-        
+
         # Get noise_mask if present
         noise_mask = latent_image.get("noise_mask", None)
-        
+
         # Track quality during sampling
         quality_history = []
         current_cfg = cfg
-        
+
         def quality_check_callback(callback_dict: dict):
             """
             Callback function called during sampling steps.
             Decodes latents, checks quality, and adjusts parameters.
-            
+
             Args:
                 callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised'
             """
             nonlocal current_cfg, quality_history
-            
+
             step = callback_dict['i']
             denoised = callback_dict['denoised']  # This is the predicted x0
-            
+
             # Only check at specified intervals
             if check_interval > 0 and step % check_interval != 0:
                 return
-            
+
             try:
                 # Decode latent to image using VAE
                 # denoised is the predicted x0 in latent space [B, C, H, W]
@@ -347,26 +346,26 @@ class InferenceScalingNode(IO.ComfyNode):
                 # Ensure denoised is a proper tensor
                 if not isinstance(denoised, torch.Tensor):
                     return
-                
+
                 # VAE.decode expects [B, C, H, W] format and handles device management
                 # It returns [B, H, W, C] in range [0, 1]
                 decoded = vae.decode(denoised)
-                
+
                 # Ensure decoded is in correct format [B, H, W, C]
                 if decoded is None:
                     return
-                
+
                 # Move to CPU for quality checking to avoid GPU memory issues
                 decoded_cpu = decoded.cpu()
-                
+
                 # Check quality
                 is_acceptable, quality_score = verifier.verify(decoded_cpu)
                 quality_history.append((step, quality_score, is_acceptable))
-                
+
                 # Log quality for debugging
                 import logging
                 logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")
-                
+
                 # Note: CFG adjustment here won't affect current sampling run
                 # as CFG is set at the start. This is for future reference/logging.
                 if not is_acceptable and quality_score < quality_threshold:
@@ -375,15 +374,14 @@
                 elif is_acceptable and quality_score > quality_threshold * 1.2:
                     # Quality is good
                     logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
             except Exception as e:
                 # If decoding fails, continue sampling (don't break the generation)
                 import logging
-                import traceback
                 logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")
                 # Don't log full traceback in production to avoid spam
                 # logging.debug(traceback.format_exc())
 
 
         # Perform sampling with quality monitoring
         try:
             # Use comfy.sample.sample which handles everything properly
@@ -407,13 +405,13 @@
                 disable_pbar=False,
                 seed=seed,
             )
-            
+
             # Return the result as a latent dict (preserve other keys from input)
             result_latent = latent_image.copy()
             result_latent["samples"] = result_samples
-            
+
             return IO.NodeOutput(result_latent)
 
         except Exception as e:
             import traceback
             raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}")
@@ -423,7 +421,7 @@ class InferenceScalingExtension(ComfyExtension):
     """
     Extension class that registers inference scaling nodes.
     """
-    
+
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [