mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 14:52:46 +08:00
fix: resolve ruff linting errors in nodes_inference_scaling
- Remove unused imports: Optional, Callable, Dict, PIL.Image, traceback - Remove trailing whitespace on blank lines (W293) - Remove unused local variable: device
This commit is contained in:
parent
74442533da
commit
c5ce3e1375
@ -6,12 +6,11 @@ The system monitors image quality during the generation loop and adjusts samplin
|
|||||||
parameters dynamically based on quality metrics.
|
parameters dynamically based on quality metrics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Callable, Dict, Any
|
from typing import Any
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy.samplers import KSampler
|
from comfy.samplers import KSampler
|
||||||
@ -29,29 +28,29 @@ class QualityVerifier:
|
|||||||
Base class for image quality verification.
|
Base class for image quality verification.
|
||||||
Subclasses should implement specific quality metrics.
|
Subclasses should implement specific quality metrics.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, threshold: float = 0.5):
|
def __init__(self, threshold: float = 0.5):
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
|
||||||
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||||
"""
|
"""
|
||||||
Verify image quality.
|
Verify image quality.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: Image tensor [B, H, W, C] in range [0, 1]
|
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (is_acceptable, quality_score)
|
Tuple of (is_acceptable, quality_score)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Subclasses must implement verify method")
|
raise NotImplementedError("Subclasses must implement verify method")
|
||||||
|
|
||||||
def get_quality_score(self, image: torch.Tensor) -> float:
|
def get_quality_score(self, image: torch.Tensor) -> float:
|
||||||
"""
|
"""
|
||||||
Get quality score without threshold check.
|
Get quality score without threshold check.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: Image tensor [B, H, W, C] in range [0, 1]
|
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Quality score in range [0, 1]
|
Quality score in range [0, 1]
|
||||||
"""
|
"""
|
||||||
@ -64,7 +63,7 @@ class SimpleVerifier(QualityVerifier):
|
|||||||
Simple verifier using basic image statistics.
|
Simple verifier using basic image statistics.
|
||||||
Uses variance and edge detection as quality indicators.
|
Uses variance and edge detection as quality indicators.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||||
"""
|
"""
|
||||||
Simple quality check based on image variance and structure.
|
Simple quality check based on image variance and structure.
|
||||||
@ -74,29 +73,29 @@ class SimpleVerifier(QualityVerifier):
|
|||||||
gray = image[..., :3].mean(dim=-1, keepdim=True)
|
gray = image[..., :3].mean(dim=-1, keepdim=True)
|
||||||
else:
|
else:
|
||||||
gray = image.mean(dim=-1, keepdim=True)
|
gray = image.mean(dim=-1, keepdim=True)
|
||||||
|
|
||||||
# Calculate variance (higher variance = more detail)
|
# Calculate variance (higher variance = more detail)
|
||||||
variance = gray.var().item()
|
variance = gray.var().item()
|
||||||
|
|
||||||
# Simple edge detection using Sobel-like operator
|
# Simple edge detection using Sobel-like operator
|
||||||
# Convert to numpy for easier processing
|
# Convert to numpy for easier processing
|
||||||
img_np = gray.squeeze().cpu().numpy()
|
img_np = gray.squeeze().cpu().numpy()
|
||||||
if len(img_np.shape) == 3:
|
if len(img_np.shape) == 3:
|
||||||
img_np = img_np[0] # Take first batch
|
img_np = img_np[0] # Take first batch
|
||||||
|
|
||||||
# Calculate gradients
|
# Calculate gradients
|
||||||
grad_x = np.gradient(img_np, axis=1)
|
grad_x = np.gradient(img_np, axis=1)
|
||||||
grad_y = np.gradient(img_np, axis=0)
|
grad_y = np.gradient(img_np, axis=0)
|
||||||
edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
|
edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
|
||||||
|
|
||||||
# Combine metrics (normalize to [0, 1])
|
# Combine metrics (normalize to [0, 1])
|
||||||
variance_score = min(variance * 10, 1.0) # Scale variance
|
variance_score = min(variance * 10, 1.0) # Scale variance
|
||||||
edge_score = min(edge_strength * 5, 1.0) # Scale edge strength
|
edge_score = min(edge_strength * 5, 1.0) # Scale edge strength
|
||||||
|
|
||||||
quality_score = (variance_score * 0.6 + edge_score * 0.4)
|
quality_score = (variance_score * 0.6 + edge_score * 0.4)
|
||||||
|
|
||||||
is_acceptable = quality_score >= self.threshold
|
is_acceptable = quality_score >= self.threshold
|
||||||
|
|
||||||
return is_acceptable, quality_score
|
return is_acceptable, quality_score
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +142,7 @@ class VerifierSelectionNode(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
"""
|
"""
|
||||||
Create and return a verifier identifier.
|
Create and return a verifier identifier.
|
||||||
|
|
||||||
In a real implementation, this would create and store the verifier
|
In a real implementation, this would create and store the verifier
|
||||||
instance. For now, we return a serialized identifier.
|
instance. For now, we return a serialized identifier.
|
||||||
"""
|
"""
|
||||||
@ -152,23 +151,23 @@ class VerifierSelectionNode(IO.ComfyNode):
|
|||||||
verifier = SimpleVerifier(threshold=quality_threshold)
|
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||||
else:
|
else:
|
||||||
verifier = SimpleVerifier(threshold=quality_threshold)
|
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||||
|
|
||||||
# Store verifier in a way that InferenceScalingNode can access it
|
# Store verifier in a way that InferenceScalingNode can access it
|
||||||
# For now, we'll use a simple approach: store in a module-level dict
|
# For now, we'll use a simple approach: store in a module-level dict
|
||||||
import comfy_api_nodes.nodes_inference_scaling as mod
|
import comfy_api_nodes.nodes_inference_scaling as mod
|
||||||
if not hasattr(mod, '_verifier_registry'):
|
if not hasattr(mod, '_verifier_registry'):
|
||||||
mod._verifier_registry = {}
|
mod._verifier_registry = {}
|
||||||
|
|
||||||
verifier_id = f"verifier_{id(verifier)}"
|
verifier_id = f"verifier_{id(verifier)}"
|
||||||
mod._verifier_registry[verifier_id] = verifier
|
mod._verifier_registry[verifier_id] = verifier
|
||||||
|
|
||||||
return IO.NodeOutput(verifier_id)
|
return IO.NodeOutput(verifier_id)
|
||||||
|
|
||||||
|
|
||||||
class InferenceScalingNode(IO.ComfyNode):
|
class InferenceScalingNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Main inference scaling node that wraps KSampler with quality monitoring.
|
Main inference scaling node that wraps KSampler with quality monitoring.
|
||||||
|
|
||||||
This node:
|
This node:
|
||||||
1. Wraps the standard sampling process
|
1. Wraps the standard sampling process
|
||||||
2. Periodically decodes latents to images using VAE during sampling
|
2. Periodically decodes latents to images using VAE during sampling
|
||||||
@ -293,53 +292,53 @@ class InferenceScalingNode(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
"""
|
"""
|
||||||
Execute inference scaling sampling.
|
Execute inference scaling sampling.
|
||||||
|
|
||||||
This wraps the standard sampling process and adds quality monitoring.
|
This wraps the standard sampling process and adds quality monitoring.
|
||||||
"""
|
"""
|
||||||
# Get verifier from registry
|
# Get verifier from registry
|
||||||
import comfy_api_nodes.nodes_inference_scaling as mod
|
import comfy_api_nodes.nodes_inference_scaling as mod
|
||||||
if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
|
if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
|
||||||
raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
|
raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")
|
||||||
|
|
||||||
verifier = mod._verifier_registry[verifier_id]
|
verifier = mod._verifier_registry[verifier_id]
|
||||||
|
|
||||||
# Prepare sampling parameters
|
# Prepare sampling parameters
|
||||||
device = model_management.get_torch_device()
|
model_management.get_torch_device()
|
||||||
|
|
||||||
# Extract latent_image tensor from dict
|
# Extract latent_image tensor from dict
|
||||||
latent_image_tensor = latent_image["samples"]
|
latent_image_tensor = latent_image["samples"]
|
||||||
|
|
||||||
# Fix empty latent channels if needed
|
# Fix empty latent channels if needed
|
||||||
latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
|
latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)
|
||||||
|
|
||||||
# Prepare noise from latent_image and seed (same as common_ksampler does)
|
# Prepare noise from latent_image and seed (same as common_ksampler does)
|
||||||
batch_inds = latent_image.get("batch_index", None)
|
batch_inds = latent_image.get("batch_index", None)
|
||||||
noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
|
noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)
|
||||||
|
|
||||||
# Get noise_mask if present
|
# Get noise_mask if present
|
||||||
noise_mask = latent_image.get("noise_mask", None)
|
noise_mask = latent_image.get("noise_mask", None)
|
||||||
|
|
||||||
# Track quality during sampling
|
# Track quality during sampling
|
||||||
quality_history = []
|
quality_history = []
|
||||||
current_cfg = cfg
|
current_cfg = cfg
|
||||||
|
|
||||||
def quality_check_callback(callback_dict: dict):
|
def quality_check_callback(callback_dict: dict):
|
||||||
"""
|
"""
|
||||||
Callback function called during sampling steps.
|
Callback function called during sampling steps.
|
||||||
Decodes latents, checks quality, and adjusts parameters.
|
Decodes latents, checks quality, and adjusts parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised'
|
callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised'
|
||||||
"""
|
"""
|
||||||
nonlocal current_cfg, quality_history
|
nonlocal current_cfg, quality_history
|
||||||
|
|
||||||
step = callback_dict['i']
|
step = callback_dict['i']
|
||||||
denoised = callback_dict['denoised'] # This is the predicted x0
|
denoised = callback_dict['denoised'] # This is the predicted x0
|
||||||
|
|
||||||
# Only check at specified intervals
|
# Only check at specified intervals
|
||||||
if check_interval > 0 and step % check_interval != 0:
|
if check_interval > 0 and step % check_interval != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Decode latent to image using VAE
|
# Decode latent to image using VAE
|
||||||
# denoised is the predicted x0 in latent space [B, C, H, W]
|
# denoised is the predicted x0 in latent space [B, C, H, W]
|
||||||
@ -347,26 +346,26 @@ class InferenceScalingNode(IO.ComfyNode):
|
|||||||
# Ensure denoised is a proper tensor
|
# Ensure denoised is a proper tensor
|
||||||
if not isinstance(denoised, torch.Tensor):
|
if not isinstance(denoised, torch.Tensor):
|
||||||
return
|
return
|
||||||
|
|
||||||
# VAE.decode expects [B, C, H, W] format and handles device management
|
# VAE.decode expects [B, C, H, W] format and handles device management
|
||||||
# It returns [B, H, W, C] in range [0, 1]
|
# It returns [B, H, W, C] in range [0, 1]
|
||||||
decoded = vae.decode(denoised)
|
decoded = vae.decode(denoised)
|
||||||
|
|
||||||
# Ensure decoded is in correct format [B, H, W, C]
|
# Ensure decoded is in correct format [B, H, W, C]
|
||||||
if decoded is None:
|
if decoded is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Move to CPU for quality checking to avoid GPU memory issues
|
# Move to CPU for quality checking to avoid GPU memory issues
|
||||||
decoded_cpu = decoded.cpu()
|
decoded_cpu = decoded.cpu()
|
||||||
|
|
||||||
# Check quality
|
# Check quality
|
||||||
is_acceptable, quality_score = verifier.verify(decoded_cpu)
|
is_acceptable, quality_score = verifier.verify(decoded_cpu)
|
||||||
quality_history.append((step, quality_score, is_acceptable))
|
quality_history.append((step, quality_score, is_acceptable))
|
||||||
|
|
||||||
# Log quality for debugging
|
# Log quality for debugging
|
||||||
import logging
|
import logging
|
||||||
logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")
|
logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")
|
||||||
|
|
||||||
# Note: CFG adjustment here won't affect current sampling run
|
# Note: CFG adjustment here won't affect current sampling run
|
||||||
# as CFG is set at the start. This is for future reference/logging.
|
# as CFG is set at the start. This is for future reference/logging.
|
||||||
if not is_acceptable and quality_score < quality_threshold:
|
if not is_acceptable and quality_score < quality_threshold:
|
||||||
@ -375,15 +374,14 @@ class InferenceScalingNode(IO.ComfyNode):
|
|||||||
elif is_acceptable and quality_score > quality_threshold * 1.2:
|
elif is_acceptable and quality_score > quality_threshold * 1.2:
|
||||||
# Quality is good
|
# Quality is good
|
||||||
logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
|
logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If decoding fails, continue sampling (don't break the generation)
|
# If decoding fails, continue sampling (don't break the generation)
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
|
||||||
logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")
|
logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")
|
||||||
# Don't log full traceback in production to avoid spam
|
# Don't log full traceback in production to avoid spam
|
||||||
# logging.debug(traceback.format_exc())
|
# logging.debug(traceback.format_exc())
|
||||||
|
|
||||||
# Perform sampling with quality monitoring
|
# Perform sampling with quality monitoring
|
||||||
try:
|
try:
|
||||||
# Use comfy.sample.sample which handles everything properly
|
# Use comfy.sample.sample which handles everything properly
|
||||||
@ -407,13 +405,13 @@ class InferenceScalingNode(IO.ComfyNode):
|
|||||||
disable_pbar=False,
|
disable_pbar=False,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return the result as a latent dict (preserve other keys from input)
|
# Return the result as a latent dict (preserve other keys from input)
|
||||||
result_latent = latent_image.copy()
|
result_latent = latent_image.copy()
|
||||||
result_latent["samples"] = result_samples
|
result_latent["samples"] = result_samples
|
||||||
|
|
||||||
return IO.NodeOutput(result_latent)
|
return IO.NodeOutput(result_latent)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}")
|
raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}")
|
||||||
@ -423,7 +421,7 @@ class InferenceScalingExtension(ComfyExtension):
|
|||||||
"""
|
"""
|
||||||
Extension class that registers inference scaling nodes.
|
Extension class that registers inference scaling nodes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user