fix: resolve ruff linting errors in nodes_inference_scaling

- Remove unused imports: Optional, Callable, Dict, PIL.Image, traceback
- Remove trailing whitespace on blank lines (W293)
- Remove unused local variable: device
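These correspond to ruff rules F401 (unused import), W293 (whitespace on blank line), and F841 (local variable assigned but never used). As a rough sketch, the warnings fixed here could be reproduced with ruff's CLI, assuming ruff is installed and the file path (inferred from the module name used in the diff) is correct:

    ruff check --select F401,F841,W293 comfy_api_nodes/nodes_inference_scaling.py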
unknown 2026-04-18 15:31:56 +09:00
parent 74442533da
commit c5ce3e1375

@@ -6,12 +6,11 @@ The system monitors image quality during the generation loop and adjusts sampling
parameters dynamically based on quality metrics.
"""
-from typing import Optional, Callable, Dict, Any
+from typing import Any
from typing_extensions import override
from inspect import cleandoc
import torch
import numpy as np
-from PIL import Image
from comfy_api.latest import IO, ComfyExtension
from comfy.samplers import KSampler
@@ -29,29 +28,29 @@ class QualityVerifier:
    Base class for image quality verification.
    Subclasses should implement specific quality metrics.
    """
    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """
        Verify image quality.
        Args:
            image: Image tensor [B, H, W, C] in range [0, 1]
        Returns:
            Tuple of (is_acceptable, quality_score)
        """
        raise NotImplementedError("Subclasses must implement verify method")
    def get_quality_score(self, image: torch.Tensor) -> float:
        """
        Get quality score without threshold check.
        Args:
            image: Image tensor [B, H, W, C] in range [0, 1]
        Returns:
            Quality score in range [0, 1]
        """
@@ -64,7 +63,7 @@ class SimpleVerifier(QualityVerifier):
    Simple verifier using basic image statistics.
    Uses variance and edge detection as quality indicators.
    """
    def verify(self, image: torch.Tensor) -> tuple[bool, float]:
        """
        Simple quality check based on image variance and structure.
@@ -74,29 +73,29 @@ class SimpleVerifier(QualityVerifier):
            gray = image[..., :3].mean(dim=-1, keepdim=True)
        else:
            gray = image.mean(dim=-1, keepdim=True)
        # Calculate variance (higher variance = more detail)
        variance = gray.var().item()
        # Simple edge detection using Sobel-like operator
        # Convert to numpy for easier processing
        img_np = gray.squeeze().cpu().numpy()
        if len(img_np.shape) == 3:
            img_np = img_np[0]  # Take first batch
        # Calculate gradients
        grad_x = np.gradient(img_np, axis=1)
        grad_y = np.gradient(img_np, axis=0)
        edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
        # Combine metrics (normalize to [0, 1])
        variance_score = min(variance * 10, 1.0)  # Scale variance
        edge_score = min(edge_strength * 5, 1.0)  # Scale edge strength
        quality_score = (variance_score * 0.6 + edge_score * 0.4)
        is_acceptable = quality_score >= self.threshold
        return is_acceptable, quality_score
@@ -143,7 +142,7 @@ class VerifierSelectionNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        """
        Create and return a verifier identifier.
        In a real implementation, this would create and store the verifier
        instance. For now, we return a serialized identifier.
        """
@@ -152,23 +151,23 @@ class VerifierSelectionNode(IO.ComfyNode):
            verifier = SimpleVerifier(threshold=quality_threshold)
        else:
            verifier = SimpleVerifier(threshold=quality_threshold)
        # Store verifier in a way that InferenceScalingNode can access it
        # For now, we'll use a simple approach: store in a module-level dict
        import comfy_api_nodes.nodes_inference_scaling as mod
        if not hasattr(mod, '_verifier_registry'):
            mod._verifier_registry = {}
        verifier_id = f"verifier_{id(verifier)}"
        mod._verifier_registry[verifier_id] = verifier
        return IO.NodeOutput(verifier_id)

class InferenceScalingNode(IO.ComfyNode):
    """
    Main inference scaling node that wraps KSampler with quality monitoring.
    This node:
    1. Wraps the standard sampling process
    2. Periodically decodes latents to images using VAE during sampling
@@ -293,53 +292,53 @@ class InferenceScalingNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        """
        Execute inference scaling sampling.
        This wraps the standard sampling process and adds quality monitoring.
        """
        # Get verifier from registry
        import comfy_api_nodes.nodes_inference_scaling as mod
        if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
            raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
        verifier = mod._verifier_registry[verifier_id]
        # Prepare sampling parameters
-        device = model_management.get_torch_device()
+        model_management.get_torch_device()
        # Extract latent_image tensor from dict
        latent_image_tensor = latent_image["samples"]
        # Fix empty latent channels if needed
        latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
        # Prepare noise from latent_image and seed (same as common_ksampler does)
        batch_inds = latent_image.get("batch_index", None)
        noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
        # Get noise_mask if present
        noise_mask = latent_image.get("noise_mask", None)
        # Track quality during sampling
        quality_history = []
        current_cfg = cfg
        def quality_check_callback(callback_dict: dict):
            """
            Callback function called during sampling steps.
            Decodes latents, checks quality, and adjusts parameters.
            Args:
                callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised'
            """
            nonlocal current_cfg, quality_history
            step = callback_dict['i']
            denoised = callback_dict['denoised']  # This is the predicted x0
            # Only check at specified intervals
            if check_interval > 0 and step % check_interval != 0:
                return
            try:
                # Decode latent to image using VAE
                # denoised is the predicted x0 in latent space [B, C, H, W]
@@ -347,26 +346,26 @@ class InferenceScalingNode(IO.ComfyNode):
                # Ensure denoised is a proper tensor
                if not isinstance(denoised, torch.Tensor):
                    return
                # VAE.decode expects [B, C, H, W] format and handles device management
                # It returns [B, H, W, C] in range [0, 1]
                decoded = vae.decode(denoised)
                # Ensure decoded is in correct format [B, H, W, C]
                if decoded is None:
                    return
                # Move to CPU for quality checking to avoid GPU memory issues
                decoded_cpu = decoded.cpu()
                # Check quality
                is_acceptable, quality_score = verifier.verify(decoded_cpu)
                quality_history.append((step, quality_score, is_acceptable))
                # Log quality for debugging
                import logging
                logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")
                # Note: CFG adjustment here won't affect current sampling run
                # as CFG is set at the start. This is for future reference/logging.
                if not is_acceptable and quality_score < quality_threshold:
@@ -375,15 +374,14 @@ class InferenceScalingNode(IO.ComfyNode):
                elif is_acceptable and quality_score > quality_threshold * 1.2:
                    # Quality is good
                    logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
            except Exception as e:
                # If decoding fails, continue sampling (don't break the generation)
                import logging
-                import traceback
                logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")
                # Don't log full traceback in production to avoid spam
                # logging.debug(traceback.format_exc())
        # Perform sampling with quality monitoring
        try:
            # Use comfy.sample.sample which handles everything properly
@@ -407,13 +405,13 @@ class InferenceScalingNode(IO.ComfyNode):
                disable_pbar=False,
                seed=seed,
            )
            # Return the result as a latent dict (preserve other keys from input)
            result_latent = latent_image.copy()
            result_latent["samples"] = result_samples
            return IO.NodeOutput(result_latent)
        except Exception as e:
            import traceback
            raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}")
@@ -423,7 +421,7 @@ class InferenceScalingExtension(ComfyExtension):
    """
    Extension class that registers inference scaling nodes.
    """
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [