mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-31 01:00:53 +08:00
Add Batch Images/Masks/Latents node
This commit is contained in:
parent
9cbfb96bf7
commit
8590bcf48a
@ -9,6 +9,7 @@ from typing import TypedDict, Literal
|
|||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy_extras.nodes_latent import reshape_latent_to
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from nodes import MAX_RESOLUTION
|
from nodes import MAX_RESOLUTION
|
||||||
@ -443,6 +444,94 @@ class ResizeImageMaskNode(io.ComfyNode):
|
|||||||
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
|
||||||
raise ValueError(f"Unsupported resize type: {selected_type}")
|
raise ValueError(f"Unsupported resize type: {selected_type}")
|
||||||
|
|
||||||
|
def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(images) == 0:
|
||||||
|
return None
|
||||||
|
# first, get the max channels count
|
||||||
|
max_channels = max(image.shape[-1] for image in images)
|
||||||
|
# then, pad all images to have the same channels count
|
||||||
|
padded_images: list[torch.Tensor] = []
|
||||||
|
for image in images:
|
||||||
|
if image.shape[-1] < max_channels:
|
||||||
|
padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0))
|
||||||
|
else:
|
||||||
|
padded_images.append(image)
|
||||||
|
# resize all images to be the same size as the first image
|
||||||
|
resized_images: list[torch.Tensor] = []
|
||||||
|
first_image_shape = padded_images[0].shape
|
||||||
|
for image in padded_images:
|
||||||
|
if image.shape[1:] != first_image_shape[1:]:
|
||||||
|
resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1))
|
||||||
|
else:
|
||||||
|
resized_images.append(image)
|
||||||
|
# batch the images in the format [b, h, w, c]
|
||||||
|
return torch.cat(resized_images, dim=0)
|
||||||
|
|
||||||
|
def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None:
|
||||||
|
if len(masks) == 0:
|
||||||
|
return None
|
||||||
|
# resize all masks to be the same size as the first mask
|
||||||
|
resized_masks: list[torch.Tensor] = []
|
||||||
|
first_mask_shape = masks[0].shape
|
||||||
|
for mask in masks:
|
||||||
|
if mask.shape[1:] != first_mask_shape[1:]:
|
||||||
|
mask = init_image_mask_input(mask, is_type_image=False)
|
||||||
|
mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center")
|
||||||
|
resized_masks.append(finalize_image_mask_input(mask, is_type_image=False))
|
||||||
|
else:
|
||||||
|
resized_masks.append(mask)
|
||||||
|
# batch the masks in the format [b, h, w]
|
||||||
|
return torch.cat(resized_masks, dim=0)
|
||||||
|
|
||||||
|
def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None:
|
||||||
|
if len(latents) == 0:
|
||||||
|
return None
|
||||||
|
samples_out = latents[0].copy()
|
||||||
|
samples_out["batch_index"] = []
|
||||||
|
first_samples = latents[0]["samples"]
|
||||||
|
tensors: list[torch.Tensor] = []
|
||||||
|
for latent in latents:
|
||||||
|
# first, deal with latent tensors
|
||||||
|
tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False))
|
||||||
|
# next, deal with batch_index
|
||||||
|
samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])]))
|
||||||
|
samples_out["samples"] = torch.cat(tensors, dim=0)
|
||||||
|
return samples_out
|
||||||
|
|
||||||
|
class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent])
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("input", matchtype_template),
|
||||||
|
prefix="input", min=1, max=50)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="BatchImagesMasksLatentsNode",
|
||||||
|
display_name="Batch Images/Masks/Latents",
|
||||||
|
category="util",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("inputs", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(id=None, template=matchtype_template)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
batched = None
|
||||||
|
values = list(inputs.values())
|
||||||
|
# latents
|
||||||
|
if isinstance(values[0], dict):
|
||||||
|
batched = batch_latents(values)
|
||||||
|
# images
|
||||||
|
elif is_image(values[0]):
|
||||||
|
batched = batch_images(values)
|
||||||
|
# masks
|
||||||
|
else:
|
||||||
|
batched = batch_masks(values)
|
||||||
|
return io.NodeOutput(batched)
|
||||||
|
|
||||||
class PostProcessingExtension(ComfyExtension):
|
class PostProcessingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -453,6 +542,7 @@ class PostProcessingExtension(ComfyExtension):
|
|||||||
Sharpen,
|
Sharpen,
|
||||||
ImageScaleToTotalPixels,
|
ImageScaleToTotalPixels,
|
||||||
ResizeImageMaskNode,
|
ResizeImageMaskNode,
|
||||||
|
BatchImagesMasksLatentsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> PostProcessingExtension:
|
async def comfy_entrypoint() -> PostProcessingExtension:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user