mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-05 23:02:49 +08:00
Add inference scaling nodes with quality monitoring during generation
- Implement VerifierSelectionNode for creating quality verifiers
- Implement InferenceScalingNode that wraps KSampler with quality checks
- Add VAE decoding during sampling steps for quality assessment
- Include quality verification logic using variance and edge detection
- Fix noise/latent handling to match ComfyUI patterns
- Add comprehensive error handling and logging
- Include documentation and debugging notes
This commit is contained in:
parent f21f6b2212
commit c254fa10d3
95
DEBUGGING_NOTES.md
Normal file
@@ -0,0 +1,95 @@
# Debugging Notes - Inference Scaling Implementation

## Issues Found and Fixed

### 1. **Noise vs Latent Image Handling** ✅ FIXED

**Issue:** The `samples` tensor was incorrectly extracted from the `latent_image` dict and passed as `noise`.

**Fix:**
- Now properly extracts `latent_image_tensor` from the dict
- Uses `comfy.sample.prepare_noise()` to generate noise from the latent image and seed (same as `common_ksampler`)
- Passes both `noise` and `latent_image_tensor` to `comfy.sample.sample()`
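The corrected flow can be sketched as below. This is a hedged sketch mirroring `common_ksampler` in `nodes.py`; `sample_with_noise` is a hypothetical helper name, and `comfy.sample` is only importable inside a ComfyUI environment:

```python
# Sketch of the corrected noise/latent flow (assumptions: ComfyUI's
# comfy.sample.prepare_noise / comfy.sample.sample signatures as used by
# common_ksampler; sample_with_noise is a hypothetical wrapper).
def sample_with_noise(model, positive, negative, latent_image, seed, steps, cfg,
                      sampler_name, scheduler, callback=None):
    import comfy.sample  # deferred: requires a ComfyUI environment

    latent_tensor = latent_image["samples"]  # extract the tensor from the dict
    noise = comfy.sample.prepare_noise(latent_tensor, seed,
                                       latent_image.get("batch_index"))
    samples = comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler,
        positive, negative, latent_tensor,
        noise_mask=latent_image.get("noise_mask"),
        callback=callback, seed=seed,
    )
    out = latent_image.copy()  # preserve extra keys (batch_index, noise_mask, ...)
    out["samples"] = samples
    return out
```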
### 2. **Sampler Usage** ✅ FIXED

**Issue:** A `KSampler` instance was being created; `comfy.sample.sample()` should be used directly.

**Fix:**
- Now uses `comfy.sample.sample()`, which handles all of the setup properly
- This matches how `common_ksampler` works in `nodes.py`

### 3. **VAE Decoding** ✅ FIXED

**Issue:** Needed to ensure proper tensor format and device handling.

**Fix:**
- `VAE.decode` expects a `[B, C, H, W]` tensor and returns `[B, H, W, C]`
- Added proper error handling and tensor validation
- Moves decoded images to the CPU for quality checking to avoid GPU memory issues
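The tensor-format contract above can be captured in a tiny validation step. This is a sketch with a hypothetical `decode_for_check` helper, not part of ComfyUI itself; the decoder stand-in follows the stated convention of `[B, C, H, W]` in, `[B, H, W, C]` out:

```python
import numpy as np

def decode_for_check(decode_fn, latent_bchw):
    """Decode a latent for a quality check, validating the expected layout.

    decode_fn is a stand-in for VAE.decode: it takes [B, C, H, W] and is
    expected to return [B, H, W, C]. Hypothetical helper for illustration.
    """
    images = np.asarray(decode_fn(latent_bchw), dtype=np.float32)
    if images.ndim != 4:
        raise ValueError(f"expected a 4D [B, H, W, C] result, got shape {images.shape}")
    return np.clip(images, 0.0, 1.0)  # keep quality scores well-defined on a CPU copy
```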
### 4. **Callback Scope Issues** ✅ FIXED

**Issue:** Variables such as `steps` were not accessible in the callback closure.

**Fix:**
- Removed the reference to `steps` in logging (not needed)
- Added `nonlocal` declarations for variables modified in the callback
- Added a check for `check_interval > 0` to avoid division by zero
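The closure fixes above can be sketched in isolation. `make_quality_callback` and `score_fn` are hypothetical names; `score_fn` stands in for the real VAE-decode-plus-verifier step, and the callback signature follows the ComfyUI convention of `(step, x0, x, total_steps)`:

```python
def make_quality_callback(check_interval, score_fn):
    """Build a sampling callback that runs score_fn every check_interval steps.

    Illustrates the fixes described above: counters that are rebound inside the
    closure are declared nonlocal, check_interval is guarded before the modulo,
    and a score failure never aborts sampling.
    """
    checks_run = 0
    last_score = None

    def callback(step, x0, x, total_steps):
        nonlocal checks_run, last_score  # rebinding outer variables requires nonlocal
        if check_interval > 0 and (step + 1) % check_interval == 0:
            try:
                last_score = score_fn(x0)  # stand-in for VAE decode + verifier
                checks_run += 1
            except Exception:
                pass  # quality-check failures must not stop sampling

    def stats():
        return checks_run, last_score

    return callback, stats


cb, stats = make_quality_callback(5, lambda latent: 0.7)
for s in range(20):
    cb(s, None, None, 20)
print(stats())  # -> (4, 0.7): checks fire at steps 4, 9, 14, 19
```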
### 5. **Error Handling** ✅ IMPROVED

**Issue:** Errors could break the sampling process.

**Fix:**
- Added comprehensive try/except blocks
- Quality-check failures no longer stop sampling
- Better logging with appropriate log levels
- Tracebacks are emitted only in debug mode to avoid log spam

### 6. **Result Format** ✅ FIXED

**Issue:** All keys from the input latent dict need to be preserved.

**Fix:**
- Uses `latent_image.copy()` to preserve all keys (`batch_index`, `noise_mask`, etc.)
- Only the `samples` key is updated with the result
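A minimal demonstration of the result-format fix, using placeholder strings instead of real tensors:

```python
# Copy the input latent dict so extra keys survive, then overwrite only "samples".
latent_image = {"samples": "old_tensor", "batch_index": [0], "noise_mask": None}

out = latent_image.copy()      # shallow copy preserves batch_index, noise_mask, ...
out["samples"] = "new_tensor"  # only the samples key is replaced

assert out["batch_index"] == [0] and "noise_mask" in out
assert latent_image["samples"] == "old_tensor"  # the input dict is left untouched
```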
## Remaining Limitations

### CFG Adjustment

- **Status:** Cannot be dynamically adjusted during sampling
- **Reason:** CFG is set at the start of sampling and used throughout
- **Workaround:** Quality checking still works and provides valuable feedback
- **Future:** Could implement an adaptive step count or early stopping

### Performance

- **VAE Decoding:** Adds overhead during sampling (decodes at intervals)
- **Mitigation:** Only checks at the specified interval (default: every 5 steps)
- **Future:** Could optimize by using TAESD for faster preview decoding

## Testing Checklist

- [x] Syntax check passes (`py_compile`)
- [x] No linter errors
- [ ] Import test (requires Python 3.8+ for `get_origin`)
- [ ] Runtime test with actual ComfyUI
- [ ] Test with different samplers/schedulers
- [ ] Test with different VAE models
- [ ] Test quality verification logic
- [ ] Test error handling (invalid verifier_id, etc.)

## Code Quality Improvements Made

1. **Better Error Messages:** More descriptive exceptions with tracebacks
2. **Logging:** Added info/warning logs for quality checks
3. **Type Safety:** Added `isinstance` checks for tensors
4. **Memory Management:** Moves decoded images to the CPU
5. **Code Organization:** Follows ComfyUI patterns (`common_ksampler` style)

## Files Modified

- `comfy_api_nodes/nodes_inference_scaling.py` - Main implementation
- `DEBUGGING_NOTES.md` - This file

## Next Steps for Full Testing

1. Test in an actual ComfyUI environment
2. Verify the callback is called correctly
3. Test VAE decoding with different models
4. Verify quality metrics are reasonable
5. Test edge cases (empty latents, different batch sizes, etc.)
434
HOW_TO_WRITE_NODES.md
Normal file
@@ -0,0 +1,434 @@
# How to Write Nodes in ComfyUI-Inference-Scaling

This guide walks you through creating custom nodes for ComfyUI using the modern API system.

## Overview

In this ComfyUI project, nodes are created using the **ComfyAPI** system. Each node is a class that:

1. Inherits from `IO.ComfyNode`
2. Defines its schema (inputs, outputs, metadata) via `define_schema()`
3. Implements its execution logic via `execute()`
4. Is registered through a `ComfyExtension` class

## Basic Structure

### 1. Node Class

Every node is a class that inherits from `IO.ComfyNode`:

```python
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
from inspect import cleandoc


class MyCustomNode(IO.ComfyNode):
    """
    Description of what your node does.
    This docstring will appear in the UI.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="MyCustomNode",
            display_name="My Custom Node",
            category="mycategory/subcategory",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                # Define inputs here
            ],
            outputs=[
                # Define outputs here
            ],
        )

    @classmethod
    async def execute(cls, ...) -> IO.NodeOutput:
        # Implementation here
        pass
```
### 2. Schema Definition

The `define_schema()` method defines:

- **node_id**: Unique identifier (usually matches the class name)
- **display_name**: Name shown in the UI
- **category**: Where the node appears in the node menu (use `/` for subcategories)
- **description**: Tooltip/help text
- **inputs**: List of input definitions
- **outputs**: List of output definitions
- **hidden**: Optional hidden inputs (such as auth tokens)
- **is_api_node**: Set to `True` for API nodes

### 3. Input Types

Common input types available:

```python
# String input
IO.String.Input(
    "prompt",
    default="",
    multiline=True,  # For longer text
    tooltip="Help text shown on hover",
    optional=True,  # Makes it optional
)

# Integer input
IO.Int.Input(
    "seed",
    default=0,
    min=0,
    max=100,
    step=1,
    display_mode=IO.NumberDisplay.slider,  # or .number
    control_after_generate=True,  # Shows the randomize button
    tooltip="Random seed",
)

# Float input
IO.Float.Input(
    "strength",
    default=0.5,
    min=0.0,
    max=1.0,
    step=0.01,
    tooltip="Strength value",
)

# Combo/Dropdown input
IO.Combo.Input(
    "model",
    options=["option1", "option2", "option3"],
    default="option1",
    tooltip="Select a model",
)

# Image input
IO.Image.Input(
    "image",
    tooltip="Input image",
    optional=True,
)

# Mask input
IO.Mask.Input(
    "mask",
    tooltip="Mask for inpainting",
    optional=True,
)

# Audio input
IO.Audio.Input(
    "audio",
    tooltip="Input audio",
    optional=True,
)

# Video input
IO.Video.Input(
    "video",
    tooltip="Input video",
    optional=True,
)
```
### 4. Output Types

Common output types:

```python
# Image output
IO.Image.Output()

# Audio output
IO.Audio.Output()

# Video output
IO.Video.Output()

# String output
IO.String.Output()

# Integer output
IO.Int.Output()

# Float output
IO.Float.Output()
```

### 5. Execute Method

The `execute()` method is where your node's logic runs:

```python
@classmethod
async def execute(
    cls,
    # Parameters match input names from define_schema
    prompt: str,
    seed: int = 0,
    image: Optional[torch.Tensor] = None,
    # ... other inputs
) -> IO.NodeOutput:
    """
    Execute the node logic.

    Args:
        prompt: Text prompt
        seed: Random seed
        image: Optional image tensor (shape: [B, H, W, C])
        ...

    Returns:
        IO.NodeOutput with the result
    """
    # Your implementation here

    # For image outputs:
    result_tensor = ...  # torch.Tensor with shape [B, H, W, C]
    return IO.NodeOutput(result_tensor)

    # For multiple outputs:
    return IO.NodeOutput(image=result_image, metadata=result_metadata)
```

**Important Notes:**
- The method is `async` - use `await` for async operations
- Input parameters match the names from `define_schema()`
- Image tensors have shape `[Batch, Height, Width, Channels]` (typically 3 RGB channels, 4 with alpha)
- Use `IO.NodeOutput()` to return results
### 6. Extension Registration

To register your nodes, create an extension class:

```python
class MyExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            MyCustomNode,
            AnotherNode,
            # ... list all your nodes
        ]


async def comfy_entrypoint() -> MyExtension:
    """
    Entry point function that ComfyUI calls to load your extension.
    """
    return MyExtension()
```
## Complete Example

Here's a complete example of a simple image-processing node:

```python
import torch
from typing_extensions import override
from inspect import cleandoc

from comfy_api.latest import IO, ComfyExtension


class ImageBrightnessNode(IO.ComfyNode):
    """
    Adjusts the brightness of an input image.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageBrightnessNode",
            display_name="Image Brightness",
            category="image/processing",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="Input image to adjust",
                ),
                IO.Float.Input(
                    "brightness",
                    default=1.0,
                    min=0.0,
                    max=2.0,
                    step=0.1,
                    tooltip="Brightness multiplier (1.0 = no change)",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
        )

    @classmethod
    async def execute(
        cls,
        image: torch.Tensor,
        brightness: float = 1.0,
    ) -> IO.NodeOutput:
        """
        Adjust image brightness.

        Args:
            image: Input image tensor [B, H, W, C]
            brightness: Brightness multiplier

        Returns:
            Brightness-adjusted image
        """
        # Ensure we have a batch dimension
        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        # Apply the brightness adjustment and clamp values to the [0, 1] range
        adjusted = torch.clamp(image * brightness, 0.0, 1.0)

        return IO.NodeOutput(adjusted)


class MyExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            ImageBrightnessNode,
        ]


async def comfy_entrypoint() -> MyExtension:
    return MyExtension()
```
## API Node Example

For API nodes (nodes that call external APIs), you'll typically:

1. Use utility functions from `comfy_api_nodes.util`:
   - `sync_op()` - for synchronous API calls
   - `poll_op()` - for polling async operations
   - `validate_string()` - for input validation
   - `tensor_to_bytesio()` - convert image tensors to bytes
   - `bytesio_to_image_tensor()` - convert bytes to image tensors

2. Include hidden inputs for authentication:
```python
hidden=[
    IO.Hidden.auth_token_comfy_org,
    IO.Hidden.api_key_comfy_org,
    IO.Hidden.unique_id,
],
is_api_node=True,
```

3. Use `ApiEndpoint` for API calls:
```python
from comfy_api_nodes.util import ApiEndpoint, sync_op

response = await sync_op(
    cls,
    ApiEndpoint(path="/api/endpoint", method="POST"),
    response_model=YourResponseModel,
    data=YourRequestModel(...),
    files={...},  # Optional file uploads
    content_type="application/json",
)
```
## File Organization

1. **Create your node file**: `comfy_api_nodes/nodes_yourname.py`
2. **Follow naming conventions**:
   - Node classes: `YourNodeName` (PascalCase)
   - Extension class: `YourExtension` (PascalCase)
   - File: `nodes_yourname.py` (snake_case)
3. **Import required modules**:
   - `from comfy_api.latest import IO, ComfyExtension`
   - `from typing_extensions import override`
   - `from inspect import cleandoc`

## Testing Your Node

1. **Start ComfyUI**:
```bash
python main.py
```

2. **Check that the node appears** in the node menu under your specified category

3. **Test the node** by:
   - Adding it to a workflow
   - Connecting inputs
   - Executing the workflow
## Common Patterns

### Working with Images

```python
import numpy as np
import torch
from PIL import Image

# Image tensor shape: [Batch, Height, Width, Channels]
# Channels are typically 3 (RGB), or 4 (RGBA) when alpha is present

# Convert a PIL Image to a tensor
pil_img = Image.open("image.png").convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).unsqueeze(0)  # Add batch dimension

# Convert a tensor to a PIL Image
tensor = image.squeeze(0).cpu()  # Remove batch, move to CPU
image_np = (tensor.numpy() * 255).astype(np.uint8)
pil_img = Image.fromarray(image_np)
```

### Validation

```python
from comfy_api_nodes.util import validate_string

# Validate string inputs
validate_string(prompt, strip_whitespace=False)

# Check optional inputs
if image is not None:
    # Process image
    pass
```

### Error Handling

```python
if some_condition:
    raise ValueError("Error message here")
```

## Tips

1. **Use type hints** - they help with IDE autocomplete and documentation
2. **Add tooltips** - help users understand what each input does
3. **Use `cleandoc()`** - cleans up docstrings for display
4. **Make inputs optional** when appropriate - improves usability
5. **Follow existing patterns** - look at `nodes_openai.py` or `nodes_stability.py` for examples
6. **Test thoroughly** - especially edge cases (None inputs, empty strings, etc.)

## Resources

- Existing node examples: `comfy_api_nodes/nodes_*.py`
- ComfyAPI documentation: `comfy_api/latest/`
- Utility functions: `comfy_api_nodes/util/`

## Next Steps

1. Look at existing nodes for reference
2. Start with a simple node
3. Test incrementally
4. Add more features as needed

Happy node writing! 🎨
129
INFERENCE_SCALING_README.md
Normal file
@@ -0,0 +1,129 @@
# Inference Scaling Implementation

This document describes the inference scaling feature implementation for ComfyUI.

## Overview

The inference scaling system monitors image quality during the generation loop (not just at the end) and can adjust sampling parameters based on quality metrics. It consists of two main nodes:

1. **VerifierSelectionNode** - Creates and configures an image quality verifier
2. **InferenceScalingNode** - Wraps the standard KSampler with quality monitoring

## Architecture

### VerifierSelectionNode

This node allows users to select and configure a quality verifier:

- **Inputs:**
  - `verifier_type`: Type of verifier ("simple" or "custom")
  - `quality_threshold`: Minimum quality score threshold (0.0-1.0)

- **Outputs:**
  - `verifier_id`: String identifier for the verifier (used by InferenceScalingNode)

### InferenceScalingNode

This node wraps the standard sampling process with quality monitoring:

- **Inputs:**
  - `verifier_id`: Verifier identifier from VerifierSelectionNode
  - `model`: The diffusion model
  - `positive`: Positive conditioning
  - `negative`: Negative conditioning
  - `latent_image`: Initial latent image
  - `vae`: VAE model for decoding latents during quality checks
  - `seed`: Random seed
  - `steps`: Number of sampling steps
  - `cfg`: CFG scale
  - `sampler_name`: Sampler algorithm name
  - `scheduler`: Scheduler name
  - `check_interval`: Check quality every N steps (default: 5)
  - `quality_threshold`: Quality threshold for adjustments
  - `scale_factor`: Factor to scale parameters when quality is low

- **Outputs:**
  - `latent`: The denoised latent image
## How It Works

1. **During Sampling:**
   - The node creates a callback function that is called at each sampling step
   - At intervals specified by `check_interval`, the callback:
     - Decodes the current latent representation to an image using the VAE
     - Uses the verifier to assess image quality
     - Logs quality metrics
     - Can adjust parameters based on quality (currently limited - see Limitations)

2. **Quality Verification:**
   - The `SimpleVerifier` class uses:
     - Image variance (higher = more detail)
     - Edge-detection strength (measures structure)
   - These are combined into a quality score (0.0-1.0)

3. **Scaling Logic:**
   - When quality is below the threshold: can increase CFG (though this is limited - see Limitations)
   - Quality history is tracked for analysis
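The variance-plus-edge score in step 2 can be restated in plain NumPy. This is a sketch that mirrors the `SimpleVerifier` logic in `comfy_api_nodes/nodes_inference_scaling.py` (the node itself works on torch tensors); the scale factors (10, 5) and weights (0.6, 0.4) follow that implementation, and `simple_quality` is a hypothetical name:

```python
import numpy as np

def simple_quality(gray, threshold=0.5):
    """Score a 2D grayscale array in [0, 1] by variance (detail) plus
    gradient magnitude (structure). Returns (is_acceptable, score)."""
    variance_score = min(float(gray.var()) * 10.0, 1.0)
    grad_x = np.gradient(gray, axis=1)
    grad_y = np.gradient(gray, axis=0)
    edge_score = min(float(np.sqrt(grad_x**2 + grad_y**2).mean()) * 5.0, 1.0)
    score = 0.6 * variance_score + 0.4 * edge_score
    return score >= threshold, score
```

A flat image scores 0.0 (no variance, no edges), while noisy or detailed images score higher; the threshold then decides whether the frame counts as acceptable.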
## Limitations

### CFG Adjustment

**Important:** Currently, CFG cannot be dynamically adjusted during sampling because:
- CFG is set at the beginning of sampling and used throughout
- The sampling process doesn't support mid-run parameter changes

**Workarounds:**
- Quality checking still works and provides valuable feedback
- Quality history can be used to inform future runs
- Early stopping could be implemented (not yet done)

### Future Improvements

1. **Dynamic Step Adjustment:** Implement an adaptive step count based on quality
2. **Early Stopping:** Stop sampling early if quality is consistently good
3. **Better Verifiers:** Add more sophisticated quality metrics (e.g., perceptual metrics, CLIP-based scoring)
4. **Parameter Tuning:** Implement more sophisticated parameter-adjustment strategies

## Usage Example

```
1. Create VerifierSelectionNode
   - Set verifier_type: "simple"
   - Set quality_threshold: 0.5
   - Connect output to InferenceScalingNode's verifier_id input

2. Create InferenceScalingNode
   - Connect model, positive, negative, latent_image, vae
   - Set steps: 20
   - Set cfg: 7.0
   - Set check_interval: 5 (check every 5 steps)
   - Set quality_threshold: 0.5
   - Set scale_factor: 1.2
   - Connect verifier_id from VerifierSelectionNode
```

## Files

- `comfy_api_nodes/nodes_inference_scaling.py` - Main implementation
- `INFERENCE_SCALING_README.md` - This file

## Testing

To test the implementation:

1. Start ComfyUI
2. Create a workflow with:
   - Model loading
   - Text encoding (positive/negative)
   - Empty latent image
   - VerifierSelectionNode
   - InferenceScalingNode
   - VAE decode (to see the final result)
3. Run the workflow and observe quality checks in the logs

## Notes

- VAE decoding during sampling adds computational overhead
- Quality checks are performed at intervals to balance performance and monitoring
- The verifier registry stores verifiers in memory (a simple approach for now)
439
comfy_api_nodes/nodes_inference_scaling.py
Normal file
@ -0,0 +1,439 @@
|
|||||||
|
"""
|
||||||
|
Inference Scaling Nodes for ComfyUI
|
||||||
|
|
||||||
|
This module provides nodes for adaptive inference scaling during image generation.
|
||||||
|
The system monitors image quality during the generation loop and adjusts sampling
|
||||||
|
parameters dynamically based on quality metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Callable, Dict, Any
|
||||||
|
from typing_extensions import override
|
||||||
|
from inspect import cleandoc
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from comfy.samplers import KSampler
|
||||||
|
from comfy.sd import VAE
|
||||||
|
from comfy import model_management
|
||||||
|
import comfy.sample
|
||||||
|
|
||||||
|
# Get sampler and scheduler options
|
||||||
|
SAMPLER_OPTIONS = KSampler.SAMPLERS
|
||||||
|
SCHEDULER_OPTIONS = KSampler.SCHEDULERS
|
||||||
|
|
||||||
|
|
||||||
|
class QualityVerifier:
|
||||||
|
"""
|
||||||
|
Base class for image quality verification.
|
||||||
|
Subclasses should implement specific quality metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, threshold: float = 0.5):
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||||
|
"""
|
||||||
|
Verify image quality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_acceptable, quality_score)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement verify method")
|
||||||
|
|
||||||
|
def get_quality_score(self, image: torch.Tensor) -> float:
|
||||||
|
"""
|
||||||
|
Get quality score without threshold check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image tensor [B, H, W, C] in range [0, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Quality score in range [0, 1]
|
||||||
|
"""
|
||||||
|
_, score = self.verify(image)
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleVerifier(QualityVerifier):
|
||||||
|
"""
|
||||||
|
Simple verifier using basic image statistics.
|
||||||
|
Uses variance and edge detection as quality indicators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def verify(self, image: torch.Tensor) -> tuple[bool, float]:
|
||||||
|
"""
|
||||||
|
Simple quality check based on image variance and structure.
|
||||||
|
"""
|
||||||
|
# Convert to grayscale for analysis
|
||||||
|
if image.shape[-1] == 4: # RGBA
|
||||||
|
gray = image[..., :3].mean(dim=-1, keepdim=True)
|
||||||
|
else:
|
||||||
|
gray = image.mean(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Calculate variance (higher variance = more detail)
|
||||||
|
variance = gray.var().item()
|
||||||
|
|
||||||
|
# Simple edge detection using Sobel-like operator
|
||||||
|
# Convert to numpy for easier processing
|
||||||
|
img_np = gray.squeeze().cpu().numpy()
|
||||||
|
if len(img_np.shape) == 3:
|
||||||
|
img_np = img_np[0] # Take first batch
|
||||||
|
|
||||||
|
# Calculate gradients
|
||||||
|
grad_x = np.gradient(img_np, axis=1)
|
||||||
|
grad_y = np.gradient(img_np, axis=0)
|
||||||
|
edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean()
|
||||||
|
|
||||||
|
# Combine metrics (normalize to [0, 1])
|
||||||
|
variance_score = min(variance * 10, 1.0) # Scale variance
|
||||||
|
edge_score = min(edge_strength * 5, 1.0) # Scale edge strength
|
||||||
|
|
||||||
|
quality_score = (variance_score * 0.6 + edge_score * 0.4)
|
||||||
|
|
||||||
|
is_acceptable = quality_score >= self.threshold
|
||||||
|
|
||||||
|
return is_acceptable, quality_score
|
||||||
|
|
||||||
|
|
||||||
|
class VerifierSelectionNode(IO.ComfyNode):
|
||||||
|
"""
|
||||||
|
Node for selecting or creating an image quality verifier.
|
||||||
|
The verifier is used by InferenceScalingNode to judge image quality
|
||||||
|
during the generation process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="VerifierSelectionNode",
|
||||||
|
display_name="Verifier Selection",
|
||||||
|
category="inference_scaling",
|
||||||
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"verifier_type",
|
||||||
|
options=["simple", "custom"],
|
||||||
|
default="simple",
|
||||||
|
tooltip="Type of quality verifier to use",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"quality_threshold",
|
||||||
|
default=0.5,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Minimum quality score threshold (0.0-1.0)",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.String.Output(), # Verifier identifier (for now, just a string)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
verifier_type: str = "simple",
|
||||||
|
quality_threshold: float = 0.5,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
"""
|
||||||
|
Create and return a verifier identifier.
|
||||||
|
|
||||||
|
In a real implementation, this would create and store the verifier
|
||||||
|
instance. For now, we return a serialized identifier.
|
||||||
|
"""
|
||||||
|
# Create verifier based on type
|
||||||
|
if verifier_type == "simple":
|
||||||
|
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||||
|
else:
|
||||||
|
verifier = SimpleVerifier(threshold=quality_threshold)
|
||||||
|
|
||||||
|
# Store verifier in a way that InferenceScalingNode can access it
|
||||||
|
# For now, we'll use a simple approach: store in a module-level dict
|
||||||
|
import comfy_api_nodes.nodes_inference_scaling as mod
|
||||||
|
if not hasattr(mod, '_verifier_registry'):
|
||||||
|
mod._verifier_registry = {}
|
||||||
|
|
||||||
|
verifier_id = f"verifier_{id(verifier)}"
|
||||||
|
mod._verifier_registry[verifier_id] = verifier
|
||||||
|
|
||||||
|
return IO.NodeOutput(verifier_id)
|
||||||
|
|
||||||
|
|
||||||

class InferenceScalingNode(IO.ComfyNode):
    """
    Main inference scaling node that wraps KSampler with quality monitoring.

    This node:
    1. Wraps the standard sampling process
    2. Periodically decodes latents to images using the VAE during sampling
    3. Uses a verifier to judge image quality
    4. Adjusts sampling parameters (steps, CFG) based on quality
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="InferenceScalingNode",
            display_name="Inference Scaling",
            category="inference_scaling",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.String.Input(
                    "verifier_id",
                    default="",
                    tooltip="Verifier identifier from VerifierSelectionNode",
                ),
                IO.Model.Input(
                    "model",
                    tooltip="Model to use for sampling",
                ),
                IO.Conditioning.Input(
                    "positive",
                    tooltip="Positive conditioning",
                ),
                IO.Conditioning.Input(
                    "negative",
                    tooltip="Negative conditioning",
                ),
                IO.Latent.Input(
                    "latent_image",
                    tooltip="Initial latent image",
                ),
                IO.Vae.Input(
                    "vae",
                    tooltip="VAE model for decoding latents during quality checks",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xffffffffffffffff,
                    control_after_generate=True,
                    tooltip="Random seed",
                ),
                IO.Int.Input(
                    "steps",
                    default=20,
                    min=1,
                    max=1000,
                    tooltip="Number of sampling steps",
                ),
                IO.Float.Input(
                    "cfg",
                    default=7.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    tooltip="CFG scale",
                ),
                IO.Combo.Input(
                    "sampler_name",
                    options=SAMPLER_OPTIONS,
                    default="euler",
                    tooltip="Sampler name",
                ),
                IO.Combo.Input(
                    "scheduler",
                    options=SCHEDULER_OPTIONS,
                    default="normal",
                    tooltip="Scheduler name",
                ),
                IO.Int.Input(
                    "check_interval",
                    default=5,
                    min=1,
                    max=50,
                    tooltip="Check quality every N steps",
                ),
                IO.Float.Input(
                    "quality_threshold",
                    default=0.5,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip="Quality threshold for early stopping or adjustment",
                ),
                IO.Float.Input(
                    "scale_factor",
                    default=1.2,
                    min=0.5,
                    max=2.0,
                    step=0.1,
                    tooltip="Factor to scale steps/CFG when quality is low",
                ),
            ],
            outputs=[
                IO.Latent.Output(),
            ],
        )

    @classmethod
    async def execute(
        cls,
        verifier_id: str,
        model: Any,
        positive: Any,
        negative: Any,
        latent_image: dict,
        vae: VAE,
        seed: int = 0,
        steps: int = 20,
        cfg: float = 7.0,
        sampler_name: str = "euler",
        scheduler: str = "normal",
        check_interval: int = 5,
        quality_threshold: float = 0.5,
        scale_factor: float = 1.2,
    ) -> IO.NodeOutput:
        """
        Execute inference scaling sampling.

        This wraps the standard sampling process and adds quality monitoring.
        """
        # Get the verifier from the module-level registry
        import comfy_api_nodes.nodes_inference_scaling as mod
        if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry:
            raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.")

        verifier = mod._verifier_registry[verifier_id]

        # Prepare sampling parameters
        device = model_management.get_torch_device()

        # Extract the latent tensor from the latent dict
        latent_image_tensor = latent_image["samples"]

        # Fix empty latent channels if needed
        latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor)

        # Prepare noise from the latent and seed (same as common_ksampler)
        batch_inds = latent_image.get("batch_index", None)
        noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds)

        # Get the noise mask if present
        noise_mask = latent_image.get("noise_mask", None)

        # Track quality during sampling
        quality_history = []
        current_cfg = cfg

        def quality_check_callback(step, x0, x, total_steps):
            """
            Callback invoked during sampling steps.
            Decodes the current prediction, checks quality, and logs results.

            Note: comfy.sample.sample invokes the callback with positional
            arguments (step, x0, x, total_steps), where x0 is the model's
            current prediction of the fully denoised latent.
            """
            nonlocal current_cfg, quality_history
            import logging

            denoised = x0  # predicted x0 in latent space [B, C, H, W]

            # Only check at the specified interval
            if check_interval > 0 and step % check_interval != 0:
                return

            try:
                # Decode the latent to an image using the VAE
                with torch.no_grad():
                    # Ensure denoised is a proper tensor
                    if not isinstance(denoised, torch.Tensor):
                        return

                    # VAE.decode expects [B, C, H, W], handles device
                    # management itself, and returns [B, H, W, C] in [0, 1]
                    decoded = vae.decode(denoised)
                    if decoded is None:
                        return

                    # Move to CPU for quality checking to avoid holding GPU memory
                    decoded_cpu = decoded.cpu()

                # Check quality
                is_acceptable, quality_score = verifier.verify(decoded_cpu)
                quality_history.append((step, quality_score, is_acceptable))

                logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}")

                # Note: adjusting CFG here won't affect the current sampling
                # run, since CFG is fixed at the start; this is for logging
                # and future reference only.
                if not is_acceptable and quality_score < quality_threshold:
                    logging.warning(f"Inference Scaling: Low quality detected at step {step}: {quality_score:.3f} < {quality_threshold}")
                elif is_acceptable and quality_score > quality_threshold * 1.2:
                    logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}")

            except Exception as e:
                # If decoding fails, keep sampling; don't break the generation.
                # The full traceback is deliberately not logged to avoid spam.
                logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}")

        # Perform sampling with quality monitoring
        try:
            # comfy.sample.sample handles model loading, devices, and setup
            result_samples = comfy.sample.sample(
                model=model,
                noise=noise,
                steps=steps,
                cfg=current_cfg,
                sampler_name=sampler_name,
                scheduler=scheduler,
                positive=positive,
                negative=negative,
                latent_image=latent_image_tensor,
                denoise=1.0,
                disable_noise=False,
                start_step=None,
                last_step=None,
                force_full_denoise=False,
                noise_mask=noise_mask,
                callback=quality_check_callback,
                disable_pbar=False,
                seed=seed,
            )

            # Return the result as a latent dict, preserving other input keys
            result_latent = latent_image.copy()
            result_latent["samples"] = result_samples

            return IO.NodeOutput(result_latent)

        except Exception as e:
            import traceback
            raise RuntimeError(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}") from e

class InferenceScalingExtension(ComfyExtension):
    """
    Extension class that registers the inference scaling nodes.
    """

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            VerifierSelectionNode,
            InferenceScalingNode,
        ]


async def comfy_entrypoint() -> InferenceScalingExtension:
    """
    Entry point function that ComfyUI calls to load the extension.
    """
    return InferenceScalingExtension()
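`SimpleVerifier` itself falls outside this hunk; per the commit message it judges quality "using variance and edge detection". Below is a minimal sketch of what such a verifier could look like. The class name, `verify` signature, and `(is_acceptable, quality_score)` return shape are inferred from the calls above; the specific metrics, normalization constants, and weights are assumptions, not the actual implementation.

```python
import torch


class SimpleVerifier:
    """Heuristic image-quality verifier (sketch). Scores a [B, H, W, C]
    image batch in [0, 1] using pixel variance and edge density."""

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold

    def verify(self, images: torch.Tensor) -> tuple[bool, float]:
        # Grayscale: average over the channel dimension -> [B, H, W]
        gray = images.float().mean(dim=-1)
        # Variance rewards tonal range; near-flat images score low.
        variance = gray.var(dim=(1, 2)).mean()
        # Edge density via finite differences rewards structural detail.
        dx = (gray[:, :, 1:] - gray[:, :, :-1]).abs().mean()
        dy = (gray[:, 1:, :] - gray[:, :-1, :]).abs().mean()
        edges = (dx + dy) / 2.0
        # Squash each term into [0, 1) and average (weights are arbitrary).
        score = (variance / (variance + 0.05) + edges / (edges + 0.05)) / 2.0
        score = float(score.item())
        return score >= self.threshold, score
```

A flat image scores 0.0 under this heuristic, while natural images land well inside (0, 1), so `quality_threshold` around 0.5 gives the verifier a usable operating range.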

comfy_api_nodes/nodes_template_example.py (new file, 180 lines)

"""
|
||||||
|
Template example node file.
|
||||||
|
Copy this file and modify it to create your own custom nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from typing_extensions import override
|
||||||
|
from inspect import cleandoc
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleNode(IO.ComfyNode):
|
||||||
|
"""
|
||||||
|
This is an example node that demonstrates the basic structure.
|
||||||
|
Replace this docstring with a description of what your node does.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="ExampleNode",
|
||||||
|
display_name="Example Node",
|
||||||
|
category="example",
|
||||||
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"text_input",
|
||||||
|
default="Hello, ComfyUI!",
|
||||||
|
multiline=False,
|
||||||
|
tooltip="A text input field",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"number_input",
|
||||||
|
default=42,
|
||||||
|
min=0,
|
||||||
|
max=100,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="A number input with slider",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"image_input",
|
||||||
|
tooltip="An optional image input",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.String.Output(),
|
||||||
|
IO.Int.Output(),
|
||||||
|
IO.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|

    @classmethod
    async def execute(
        cls,
        text_input: str,
        number_input: int,
        image_input: Optional[torch.Tensor] = None,
    ) -> IO.NodeOutput:
        """
        Execute the node logic.

        Args:
            text_input: The text input value
            number_input: The number input value
            image_input: Optional image tensor [B, H, W, C]

        Returns:
            NodeOutput with the processed results
        """
        # Process text
        processed_text = f"Processed: {text_input}"

        # Process number
        processed_number = number_input * 2

        # Process image if provided
        if image_input is not None:
            # Ensure a batch dimension exists
            if len(image_input.shape) == 3:
                image_input = image_input.unsqueeze(0)

            # Example: invert the image
            processed_image = 1.0 - image_input
        else:
            # Create a default mid-gray 256x256 RGB image if none provided
            # (ComfyUI IMAGE tensors are [B, H, W, C] with 3 channels)
            processed_image = torch.ones(1, 256, 256, 3) * 0.5

        # Return the outputs positionally, in the order declared in the schema
        return IO.NodeOutput(
            processed_text,
            processed_number,
            processed_image,
        )

class SimpleImageNode(IO.ComfyNode):
    """
    A simpler example with just image input/output.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="SimpleImageNode",
            display_name="Simple Image Node",
            category="example/image",
            description=cleandoc(cls.__doc__ or ""),
            inputs=[
                IO.Image.Input(
                    "image",
                    tooltip="Input image",
                ),
                IO.Float.Input(
                    "multiplier",
                    default=1.0,
                    min=0.0,
                    max=2.0,
                    step=0.1,
                    tooltip="Brightness multiplier",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
        )

    @classmethod
    async def execute(
        cls,
        image: torch.Tensor,
        multiplier: float = 1.0,
    ) -> IO.NodeOutput:
        """
        Multiply image brightness.

        Args:
            image: Input image tensor [B, H, W, C]
            multiplier: Brightness multiplier

        Returns:
            The adjusted image
        """
        # Ensure a batch dimension
        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        # Apply the multiplier and clamp to the valid range
        result = torch.clamp(image * multiplier, 0.0, 1.0)

        return IO.NodeOutput(result)


class ExampleExtension(ComfyExtension):
    """
    Extension class that registers all your nodes.
    """

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            ExampleNode,
            SimpleImageNode,
            # Add more nodes here as you create them
        ]


async def comfy_entrypoint() -> ExampleExtension:
    """
    Entry point function that ComfyUI calls to load your extension.
    This function name must be exactly 'comfy_entrypoint'.
    """
    return ExampleExtension()