Compare commits

...

4 Commits

Author SHA1 Message Date
Alexander Piskun  e995d76a18  Merge 3bf3c4aa2b into 6592bffc60  2025-12-13 21:03:31 -08:00
chaObserv  6592bffc60  seeds_2: add phi_2 variant and sampler node (#11309)
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
bigcat88  3bf3c4aa2b  fix test  2025-12-09 09:58:15 +02:00
bigcat88  0548d9c2cc  converted nodes_images.py to V3 schema  2025-12-09 09:37:36 +02:00
7 changed files with 389 additions and 362 deletions

View File

@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
             # Step 2
-            denoised_d = torch.lerp(denoised, denoised_2, fac)
-            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+            if solver_type == "phi_1":
+                denoised_d = torch.lerp(denoised, denoised_2, fac)
+                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+            elif solver_type == "phi_2":
+                b2 = ei_h_phi_2(-h_eta) / r
+                b1 = ei_h_phi_1(-h_eta) - b2
+                x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
             if inject_noise:
                 segment_factor = (r - 1) * h * eta
                 sde_noise = sde_noise * segment_factor.exp()
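For orientation, the phi_2 branch swaps the lerp-based second stage for weights built from the exponential-integrator functions, assuming the usual definitions phi_1(z) = (exp(z) - 1) / z and phi_2(z) = (phi_1(z) - 1) / z. The sketch below is a standalone illustration of those weights only; the local phi_1/phi_2 helpers are stand-ins, not the ei_h_phi_1/ei_h_phi_2 helpers defined elsewhere in sampling.py, and fac is simply the lerp weight the existing phi_1 branch already uses.

# Illustration only: both solver_type branches above expressed as a single pair
# of weights (b1, b2) applied to denoised / denoised_2.
import torch

def phi_1(z: torch.Tensor) -> torch.Tensor:
    return torch.expm1(z) / z

def phi_2(z: torch.Tensor) -> torch.Tensor:
    return (phi_1(z) - 1.0) / z

def seeds2_stage2_weights(h_eta: torch.Tensor, r: float, solver_type: str, fac: float):
    """Return (b1, b2) so that x_next = decay * x - alpha_t * (b1 * denoised + b2 * denoised_2)."""
    if solver_type == "phi_2":
        b2 = phi_2(-h_eta) / r          # mirrors ei_h_phi_2(-h_eta) / r in the diff
        b1 = phi_1(-h_eta) - b2
    else:  # "phi_1": torch.lerp(denoised, denoised_2, fac) scaled by phi_1(-h_eta)
        b1 = phi_1(-h_eta) * (1 - fac)
        b2 = phi_1(-h_eta) * fac
    return b1, b2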

View File

@@ -28,9 +28,8 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
     prune_dict, shallow_clone_class)
 from ._resources import Resources, ResourcesLocal
 from comfy_execution.graph_utils import ExecutionBlocker
-from ._util import MESH, VOXEL
-# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
+from ._util import MESH, VOXEL, SVG as _SVG
 
 class FolderType(str, Enum):
     input = "input"
@@ -656,7 +655,7 @@ class Video(ComfyTypeIO):
 @comfytype(io_type="SVG")
 class SVG(ComfyTypeIO):
-    Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3
+    Type = _SVG
 
 @comfytype(io_type="LORA_MODEL")
 class LoraModel(ComfyTypeIO):

View File

@@ -1,5 +1,6 @@
 from .video_types import VideoContainer, VideoCodec, VideoComponents
 from .geometry_types import VOXEL, MESH
+from .image_types import SVG
 
 __all__ = [
     # Utility Types
@@ -8,4 +9,5 @@ __all__ = [
     "VideoComponents",
     "VOXEL",
     "MESH",
+    "SVG",
 ]

View File

@@ -0,0 +1,18 @@
+from io import BytesIO
+
+
+class SVG:
+    """Stores SVG representations via a list of BytesIO objects."""
+
+    def __init__(self, data: list[BytesIO]):
+        self.data = data
+
+    def combine(self, other: 'SVG') -> 'SVG':
+        return SVG(self.data + other.data)
+
+    @staticmethod
+    def combine_all(svgs: list['SVG']) -> 'SVG':
+        all_svgs_list: list[BytesIO] = []
+        for svg_item in svgs:
+            all_svgs_list.extend(svg_item.data)
+        return SVG(all_svgs_list)
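A small usage sketch of the new container follows; the import path is an assumption based on the relative imports in this PR, and the SVG markup literals are placeholders.

from io import BytesIO
from comfy_api.latest._util import SVG  # assumed import path

a = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><rect width="8" height="8"/></svg>')])
b = SVG([BytesIO(b'<svg xmlns="http://www.w3.org/2000/svg"><circle r="4"/></svg>')])

merged = a.combine(b)                          # keeps both buffers, in order
everything = SVG.combine_all([a, b, merged])
print(len(merged.data), len(everything.data))  # 2 4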

View File

@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute
 
 
+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,
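The node is a thin wrapper: its execute() builds the SAMPLER object with the exact call shown above. For reference, the same object can be constructed directly (wiring it into a custom-sampling graph is outside this diff):

import comfy.samplers

# Same call the SamplerSEEDS2 node makes; values are the node's defaults
# except solver_type, set here to the new "phi_2" variant.
sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)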

View File

@@ -2,280 +2,231 @@ from __future__ import annotations
 import nodes
 import folder_paths
-from comfy.cli_args import args
-from PIL import Image
-from PIL.PngImagePlugin import PngInfo
-import numpy as np
 import json
 import os
 import re
-from io import BytesIO
-from inspect import cleandoc
 import torch
 import comfy.utils
-from comfy.comfy_types import FileLocator, IO
 from server import PromptServer
+from comfy_api.latest import ComfyExtension, IO, UI
+from typing_extensions import override
+SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
 MAX_RESOLUTION = nodes.MAX_RESOLUTION
-class ImageCrop:
+class ImageCrop(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": ("IMAGE",),
-                              "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
-                              "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
-                              "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
-                              "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
-                              }}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "crop"
-    CATEGORY = "image/transform"
-    def crop(self, image, width, height, x, y):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageCrop",
+            display_name="Image Crop",
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+                IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+                IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+                IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
         x = min(x, image.shape[2] - 1)
         y = min(y, image.shape[1] - 1)
         to_x = width + x
         to_y = height + y
         img = image[:,y:to_y, x:to_x, :]
-        return (img,)
+        return IO.NodeOutput(img)
+    crop = execute # TODO: remove
-class RepeatImageBatch:
+class RepeatImageBatch(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": ("IMAGE",),
-                              "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              }}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "repeat"
-    CATEGORY = "image/batch"
-    def repeat(self, image, amount):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RepeatImageBatch",
+            category="image/batch",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input("amount", default=1, min=1, max=4096),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, amount) -> IO.NodeOutput:
         s = image.repeat((amount, 1,1,1))
-        return (s,)
+        return IO.NodeOutput(s)
+    repeat = execute # TODO: remove
-class ImageFromBatch:
+class ImageFromBatch(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": ("IMAGE",),
-                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
-                              "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                              }}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "frombatch"
-    CATEGORY = "image/batch"
-    def frombatch(self, image, batch_index, length):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageFromBatch",
+            category="image/batch",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input("batch_index", default=0, min=0, max=4095),
+                IO.Int.Input("length", default=1, min=1, max=4096),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, batch_index, length) -> IO.NodeOutput:
         s_in = image
         batch_index = min(s_in.shape[0] - 1, batch_index)
         length = min(s_in.shape[0] - batch_index, length)
         s = s_in[batch_index:batch_index + length].clone()
-        return (s,)
+        return IO.NodeOutput(s)
+    frombatch = execute # TODO: remove
-class ImageAddNoise:
+class ImageAddNoise(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": ("IMAGE",),
-                              "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
-                              "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "repeat"
-    CATEGORY = "image"
-    def repeat(self, image, seed, strength):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageAddNoise",
+            category="image",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    control_after_generate=True,
+                    tooltip="The random seed used for creating the noise.",
+                ),
+                IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, seed, strength) -> IO.NodeOutput:
         generator = torch.manual_seed(seed)
         s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
-        return (s,)
+        return IO.NodeOutput(s)
+    repeat = execute # TODO: remove
-class SaveAnimatedWEBP:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-    methods = {"default": 4, "fastest": 0, "slowest": 6}
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"images": ("IMAGE", ),
-                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
-                     "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
-                     "lossless": ("BOOLEAN", {"default": True}),
-                     "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
-                     "method": (list(s.methods.keys()),),
-                     # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
-                     },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
-    RETURN_TYPES = ()
-    FUNCTION = "save_images"
-    OUTPUT_NODE = True
-    CATEGORY = "image/animation"
-    def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
-        method = self.methods.get(method)
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
-        results: list[FileLocator] = []
-        pil_images = []
-        for image in images:
-            i = 255. * image.cpu().numpy()
-            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            pil_images.append(img)
-        metadata = pil_images[0].getexif()
-        if not args.disable_metadata:
-            if prompt is not None:
-                metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
-            if extra_pnginfo is not None:
-                inital_exif = 0x010f
-                for x in extra_pnginfo:
-                    metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
-                    inital_exif -= 1
-        if num_frames == 0:
-            num_frames = len(pil_images)
-        c = len(pil_images)
-        for i in range(0, c, num_frames):
-            file = f"{filename}_{counter:05}_.webp"
-            pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
-            results.append({
-                "filename": file,
-                "subfolder": subfolder,
-                "type": self.type
-            })
-            counter += 1
-        animated = num_frames != 1
-        return { "ui": { "images": results, "animated": (animated,) } }
+class SaveAnimatedWEBP(IO.ComfyNode):
+    COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveAnimatedWEBP",
+            category="image/animation",
+            inputs=[
+                IO.Image.Input("images"),
+                IO.String.Input("filename_prefix", default="ComfyUI"),
+                IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
+                IO.Boolean.Input("lossless", default=True),
+                IO.Int.Input("quality", default=80, min=0, max=100),
+                IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
+                # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
+    @classmethod
+    def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
+        return IO.NodeOutput(
+            ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
+                images=images,
+                filename_prefix=filename_prefix,
+                cls=cls,
+                fps=fps,
+                lossless=lossless,
+                quality=quality,
+                method=cls.COMPRESS_METHODS.get(method)
+            )
+        )
+    save_images = execute # TODO: remove
-class SaveAnimatedPNG:
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required":
-                    {"images": ("IMAGE", ),
-                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
-                     "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
-                     "compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
-                     },
-                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
-                }
-    RETURN_TYPES = ()
-    FUNCTION = "save_images"
-    OUTPUT_NODE = True
-    CATEGORY = "image/animation"
-    def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
-        results = list()
-        pil_images = []
-        for image in images:
-            i = 255. * image.cpu().numpy()
-            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            pil_images.append(img)
-        metadata = None
-        if not args.disable_metadata:
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
-        file = f"{filename}_{counter:05}_.png"
-        pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
-        results.append({
-            "filename": file,
-            "subfolder": subfolder,
-            "type": self.type
-        })
-        return { "ui": { "images": results, "animated": (True,)} }
+class SaveAnimatedPNG(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveAnimatedPNG",
+            category="image/animation",
+            inputs=[
+                IO.Image.Input("images"),
+                IO.String.Input("filename_prefix", default="ComfyUI"),
+                IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
+                IO.Int.Input("compress_level", default=4, min=0, max=9),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
+    @classmethod
+    def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
+        return IO.NodeOutput(
+            ui=UI.ImageSaveHelper.get_save_animated_png_ui(
+                images=images,
+                filename_prefix=filename_prefix,
+                cls=cls,
+                fps=fps,
+                compress_level=compress_level,
+            )
+        )
+    save_images = execute # TODO: remove
-class SVG:
-    """
-    Stores SVG representations via a list of BytesIO objects.
-    """
-    def __init__(self, data: list[BytesIO]):
-        self.data = data
-    def combine(self, other: 'SVG') -> 'SVG':
-        return SVG(self.data + other.data)
-    @staticmethod
-    def combine_all(svgs: list['SVG']) -> 'SVG':
-        all_svgs_list: list[BytesIO] = []
-        for svg_item in svgs:
-            all_svgs_list.extend(svg_item.data)
-        return SVG(all_svgs_list)
-class ImageStitch:
+class ImageStitch(IO.ComfyNode):
     """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image1": ("IMAGE",),
-                "direction": (["right", "down", "left", "up"], {"default": "right"}),
-                "match_image_size": ("BOOLEAN", {"default": True}),
-                "spacing_width": (
-                    "INT",
-                    {"default": 0, "min": 0, "max": 1024, "step": 2},
-                ),
-                "spacing_color": (
-                    ["white", "black", "red", "green", "blue"],
-                    {"default": "white"},
-                ),
-            },
-            "optional": {
-                "image2": ("IMAGE",),
-            },
-        }
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "stitch"
-    CATEGORY = "image/transform"
-    DESCRIPTION = """
-Stitches image2 to image1 in the specified direction.
-If image2 is not provided, returns image1 unchanged.
-Optional spacing can be added between images.
-"""
-    def stitch(
-        self,
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageStitch",
+            display_name="Image Stitch",
+            description="Stitches image2 to image1 in the specified direction.\n"
+            "If image2 is not provided, returns image1 unchanged.\n"
+            "Optional spacing can be added between images.",
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image1"),
+                IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
+                IO.Boolean.Input("match_image_size", default=True),
+                IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2),
+                IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"),
+                IO.Image.Input("image2", optional=True),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(
+        cls,
         image1,
         direction,
         match_image_size,
         spacing_width,
         spacing_color,
         image2=None,
-    ):
+    ) -> IO.NodeOutput:
         if image2 is None:
-            return (image1,)
+            return IO.NodeOutput(image1)
         # Handle batch size differences
         if image1.shape[0] != image2.shape[0]:
@@ -412,36 +363,30 @@ Optional spacing can be added between images.
         images.insert(1, spacing)
         concat_dim = 2 if direction in ["left", "right"] else 1
-        return (torch.cat(images, dim=concat_dim),)
+        return IO.NodeOutput(torch.cat(images, dim=concat_dim))
+    stitch = execute # TODO: remove
-class ResizeAndPadImage:
+class ResizeAndPadImage(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "image": ("IMAGE",),
-                "target_width": ("INT", {
-                    "default": 512,
-                    "min": 1,
-                    "max": MAX_RESOLUTION,
-                    "step": 1
-                }),
-                "target_height": ("INT", {
-                    "default": 512,
-                    "min": 1,
-                    "max": MAX_RESOLUTION,
-                    "step": 1
-                }),
-                "padding_color": (["white", "black"],),
-                "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],),
-            }
-        }
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "resize_and_pad"
-    CATEGORY = "image/transform"
-    def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ResizeAndPadImage",
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+                IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
+                IO.Combo.Input("padding_color", options=["white", "black"]),
+                IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
         batch_size, orig_height, orig_width, channels = image.shape
         scale_w = target_width / orig_width
@@ -469,52 +414,47 @@ class ResizeAndPadImage:
         padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
         output = padded.permute(0, 2, 3, 1)
-        return (output,)
+        return IO.NodeOutput(output)
+    resize_and_pad = execute # TODO: remove
-class SaveSVGNode:
-    """
-    Save SVG files on disk.
-    """
-    def __init__(self):
-        self.output_dir = folder_paths.get_output_directory()
-        self.type = "output"
-        self.prefix_append = ""
-    RETURN_TYPES = ()
-    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
-    FUNCTION = "save_svg"
-    CATEGORY = "image/save" # Changed
-    OUTPUT_NODE = True
+class SaveSVGNode(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "svg": ("SVG",), # Changed
-                "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
-            },
-            "hidden": {
-                "prompt": "PROMPT",
-                "extra_pnginfo": "EXTRA_PNGINFO"
-            }
-        }
-    def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
-        filename_prefix += self.prefix_append
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        results = list()
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SaveSVGNode",
+            description="Save SVG files on disk.",
+            category="image/save",
+            inputs=[
+                IO.SVG.Input("svg"),
+                IO.String.Input(
+                    "filename_prefix",
+                    default="svg/ComfyUI",
+                    tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
+                ),
+            ],
+            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+            is_output_node=True,
+        )
+    @classmethod
+    def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
+        results: list[UI.SavedResult] = []
         # Prepare metadata JSON
         metadata_dict = {}
-        if prompt is not None:
-            metadata_dict["prompt"] = prompt
-        if extra_pnginfo is not None:
-            metadata_dict.update(extra_pnginfo)
+        if cls.hidden.prompt is not None:
+            metadata_dict["prompt"] = cls.hidden.prompt
+        if cls.hidden.extra_pnginfo is not None:
+            metadata_dict.update(cls.hidden.extra_pnginfo)
         # Convert metadata to JSON string
         metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
         for batch_number, svg_bytes in enumerate(svg.data):
             filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
             file = f"{filename_with_batch_num}_{counter:05}_.svg"
@@ -544,57 +484,64 @@ class SaveSVGNode:
             with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
                 svg_file.write(svg_content.encode('utf-8'))
-            results.append({
-                "filename": file,
-                "subfolder": subfolder,
-                "type": self.type
-            })
+            results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
             counter += 1
-        return { "ui": { "images": results } }
+        return IO.NodeOutput(ui={"images": results})
+    save_svg = execute # TODO: remove
-class GetImageSize:
+class GetImageSize(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "image": (IO.IMAGE,),
-            },
-            "hidden": {
-                "unique_id": "UNIQUE_ID",
-            }
-        }
-    RETURN_TYPES = (IO.INT, IO.INT, IO.INT)
-    RETURN_NAMES = ("width", "height", "batch_size")
-    FUNCTION = "get_size"
-    CATEGORY = "image"
-    DESCRIPTION = """Returns width and height of the image, and passes it through unchanged."""
-    def get_size(self, image, unique_id=None) -> tuple[int, int]:
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="GetImageSize",
+            display_name="Get Image Size",
+            description="Returns width and height of the image, and passes it through unchanged.",
+            category="image",
+            inputs=[
+                IO.Image.Input("image"),
+            ],
+            outputs=[
+                IO.Int.Output(display_name="width"),
+                IO.Int.Output(display_name="height"),
+                IO.Int.Output(display_name="batch_size"),
+            ],
+            hidden=[IO.Hidden.unique_id],
+        )
+    @classmethod
+    def execute(cls, image) -> IO.NodeOutput:
         height = image.shape[1]
         width = image.shape[2]
         batch_size = image.shape[0]
         # Send progress text to display size on the node
-        if unique_id:
-            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
-        return width, height, batch_size
+        if cls.hidden.unique_id:
+            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
+        return IO.NodeOutput(width, height, batch_size)
+    get_size = execute # TODO: remove
-class ImageRotate:
+class ImageRotate(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": (IO.IMAGE,),
-                              "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
-                              }}
-    RETURN_TYPES = (IO.IMAGE,)
-    FUNCTION = "rotate"
-    CATEGORY = "image/transform"
-    def rotate(self, image, rotation):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageRotate",
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, rotation) -> IO.NodeOutput:
         rotate_by = 0
         if rotation.startswith("90"):
             rotate_by = 1
@@ -604,41 +551,57 @@ class ImageRotate:
             rotate_by = 3
         image = torch.rot90(image, k=rotate_by, dims=[2, 1])
-        return (image,)
+        return IO.NodeOutput(image)
+    rotate = execute # TODO: remove
-class ImageFlip:
+class ImageFlip(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": (IO.IMAGE,),
-                              "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
-                              }}
-    RETURN_TYPES = (IO.IMAGE,)
-    FUNCTION = "flip"
-    CATEGORY = "image/transform"
-    def flip(self, image, flip_method):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageFlip",
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, flip_method) -> IO.NodeOutput:
         if flip_method.startswith("x"):
             image = torch.flip(image, dims=[1])
         elif flip_method.startswith("y"):
             image = torch.flip(image, dims=[2])
-        return (image,)
+        return IO.NodeOutput(image)
+    flip = execute # TODO: remove
-class ImageScaleToMaxDimension:
-    upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
+class ImageScaleToMaxDimension(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"image": ("IMAGE",),
-                             "upscale_method": (s.upscale_methods,),
-                             "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "upscale"
-    CATEGORY = "image/upscaling"
-    def upscale(self, image, upscale_method, largest_size):
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageScaleToMaxDimension",
+            category="image/upscaling",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input(
+                    "upscale_method",
+                    options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
+                ),
+                IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
+            ],
+            outputs=[IO.Image.Output()],
+        )
+    @classmethod
+    def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
         height = image.shape[1]
         width = image.shape[2]
@@ -655,20 +618,30 @@ class ImageScaleToMaxDimension:
         samples = image.movedim(-1, 1)
         s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
         s = s.movedim(1, -1)
-        return (s,)
+        return IO.NodeOutput(s)
+    upscale = execute # TODO: remove
-NODE_CLASS_MAPPINGS = {
-    "ImageCrop": ImageCrop,
-    "RepeatImageBatch": RepeatImageBatch,
-    "ImageFromBatch": ImageFromBatch,
-    "ImageAddNoise": ImageAddNoise,
-    "SaveAnimatedWEBP": SaveAnimatedWEBP,
-    "SaveAnimatedPNG": SaveAnimatedPNG,
-    "SaveSVGNode": SaveSVGNode,
-    "ImageStitch": ImageStitch,
-    "ResizeAndPadImage": ResizeAndPadImage,
-    "GetImageSize": GetImageSize,
-    "ImageRotate": ImageRotate,
-    "ImageFlip": ImageFlip,
-    "ImageScaleToMaxDimension": ImageScaleToMaxDimension,
-}
+class ImagesExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            ImageCrop,
+            RepeatImageBatch,
+            ImageFromBatch,
+            ImageAddNoise,
+            SaveAnimatedWEBP,
+            SaveAnimatedPNG,
+            SaveSVGNode,
+            ImageStitch,
+            ResizeAndPadImage,
+            GetImageSize,
+            ImageRotate,
+            ImageFlip,
+            ImageScaleToMaxDimension,
+        ]
+async def comfy_entrypoint() -> ImagesExtension:
+    return ImagesExtension()
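The conversion above repeats one pattern per node: a classmethod define_schema() returning IO.Schema, a classmethod execute() returning IO.NodeOutput, an alias kept under the old V1 method name, and an extension class plus comfy_entrypoint() that registers the nodes. A minimal sketch of that pattern follows, with a hypothetical ImageIdentity node; only the API names shown in this diff are assumed.

from comfy_api.latest import ComfyExtension, IO
from typing_extensions import override


class ImageIdentity(IO.ComfyNode):  # hypothetical example node, not part of this PR
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ImageIdentity",
            category="image",
            inputs=[IO.Image.Input("image")],
            outputs=[IO.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> IO.NodeOutput:
        return IO.NodeOutput(image)  # outputs are wrapped, not returned as a bare tuple

    identity = execute  # optional V1-style alias, mirroring the "# TODO: remove" aliases above


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [ImageIdentity]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()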

View File

@@ -25,7 +25,7 @@ class TestImageStitch:
         result = node.stitch(image1, "right", True, 0, "white", image2=None)
 
-        assert len(result) == 1
+        assert len(result.result) == 1
         assert torch.equal(result[0], image1)
 
     def test_basic_horizontal_stitch_right(self):
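The only adjustment the tests need reflects the new return type: execute() (and its stitch alias) now returns IO.NodeOutput instead of a plain tuple, so the length check goes through .result while indexing the returned object still reaches the individual values. Schematically, assuming the same image1 fixture as above:

out = ImageStitch.execute(image1, "right", True, 0, "white", image2=None)
assert len(out.result) == 1          # .result holds the returned output values
assert torch.equal(out[0], image1)   # indexing the NodeOutput still works, as in the unchanged assertion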