diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index e439b18ef..2815c5ffc 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -5,6 +5,7 @@ import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io import logging +import math def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode): samples_out["samples"] = torch.narrow(s1, dim, index, amount) return io.NodeOutput(samples_out) +class LatentCutToBatch(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentCutToBatch", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["t", "x", "y"]), + io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, samples, dim, slice_size) -> io.NodeOutput: + samples_out = samples.copy() + + s1 = samples["samples"] + + if "x" in dim: + dim = s1.ndim - 1 + elif "y" in dim: + dim = s1.ndim - 2 + elif "t" in dim: + dim = s1.ndim - 3 + + if dim < 2: + return io.NodeOutput(samples) + + s = s1.movedim(dim, 1) + if s.shape[1] < slice_size: + slice_size = s.shape[1] + elif s.shape[1] % slice_size != 0: + s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size] + new_shape = [-1, slice_size] + list(s.shape[2:]) + samples_out["samples"] = s.reshape(new_shape).movedim(1, dim) + return io.NodeOutput(samples_out) + class LatentBatch(io.ComfyNode): @classmethod def define_schema(cls): @@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension): LatentInterpolate, LatentConcat, LatentCut, + LatentCutToBatch, LatentBatch, LatentBatchSeedBehavior, LatentApplyOperation,