Add Batch Images/Masks/Latents node

This commit is contained in:
Jedrzej Kosinski 2025-12-26 18:33:02 -08:00
parent 9cbfb96bf7
commit 8590bcf48a

View File

@ -9,6 +9,7 @@ from typing import TypedDict, Literal
import comfy.utils
import comfy.model_management
from comfy_extras.nodes_latent import reshape_latent_to
import node_helpers
from comfy_api.latest import ComfyExtension, io
from nodes import MAX_RESOLUTION
@ -443,6 +444,94 @@ class ResizeImageMaskNode(io.ComfyNode):
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
raise ValueError(f"Unsupported resize type: {selected_type}")
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
if len(images) == 0:
return None
# first, get the max channels count
max_channels = max(image.shape[-1] for image in images)
# then, pad all images to have the same channels count
padded_images: list[torch.Tensor] = []
for image in images:
if image.shape[-1] < max_channels:
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
else:
padded_images.append(image)
# resize all images to be the same size as the first image
resized_images: list[torch.Tensor] = []
first_image_shape = padded_images[0].shape
for image in padded_images:
if image.shape[1:] != first_image_shape[1:]:
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
else:
resized_images.append(image)
# batch the images in the format [b, h, w, c]
return torch.cat(resized_images, dim=0)
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
if len(masks) == 0:
return None
# resize all masks to be the same size as the first mask
resized_masks: list[torch.Tensor] = []
first_mask_shape = masks[0].shape
for mask in masks:
if mask.shape[1:] != first_mask_shape[1:]:
mask = init_image_mask_input(mask, is_type_image=False)
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
else:
resized_masks.append(mask)
# batch the masks in the format [b, h, w]
return torch.cat(resized_masks, dim=0)
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
if len(latents) == 0:
return None
samples_out = latents[0].copy()
samples_out["batch_index"] = []
first_samples = latents[0]["samples"]
tensors: list[torch.Tensor] = []
for latent in latents:
# first, deal with latent tensors
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
# next, deal with batch_index
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
samples_out["samples"] = torch.cat(tensors, dim=0)
return samples_out
class BatchImagesMasksLatentsNode(io.ComfyNode):
@classmethod
def define_schema(cls):
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("input", matchtype_template),
prefix="input", min=1, max=50)
return io.Schema(
node_id="BatchImagesMasksLatentsNode",
display_name="Batch Images/Masks/Latents",
category="util",
inputs=[
io.Autogrow.Input("inputs", template=autogrow_template)
],
outputs=[
io.MatchType.Output(id=None, template=matchtype_template)
]
)
@classmethod
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
batched = None
values = list(inputs.values())
# latents
if isinstance(values[0], dict):
batched = batch_latents(values)
# images
elif is_image(values[0]):
batched = batch_images(values)
# masks
else:
batched = batch_masks(values)
return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -453,6 +542,7 @@ class PostProcessingExtension(ComfyExtension):
Sharpen,
ImageScaleToTotalPixels,
ResizeImageMaskNode,
BatchImagesMasksLatentsNode,
]
async def comfy_entrypoint() -> PostProcessingExtension: