convert nodes_rebatch.py to V3 schema (#9945)

This commit is contained in:
Alexander Piskun 2025-09-27 00:10:49 +03:00 committed by GitHub
parent c4a46e943c
commit 76eb1d72c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,18 +1,25 @@
from typing_extensions import override
import torch
class LatentRebatch:
from comfy_api.latest import ComfyExtension, io
class LatentRebatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
def define_schema(cls):
return io.Schema(
node_id="RebatchLatents",
display_name="Rebatch Latents",
category="latent/batch",
is_input_list=True,
inputs=[
io.Latent.Input("latents"),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(is_output_list=True),
],
)
@staticmethod
def get_batch(latents, list_ind, offset):
@ -53,7 +60,8 @@ class LatentRebatch:
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
@classmethod
def execute(cls, latents, batch_size):
batch_size = batch_size[0]
output_list = []
@ -63,24 +71,24 @@ class LatentRebatch:
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
next_batch = cls.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
current_batch = cls.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
@ -89,7 +97,7 @@ class LatentRebatch:
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
@ -97,23 +105,27 @@ class LatentRebatch:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
return io.NodeOutput(output_list)
class ImageRebatch:
class ImageRebatch(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
def define_schema(cls):
return io.Schema(
node_id="RebatchImages",
display_name="Rebatch Images",
category="image/batch",
is_input_list=True,
inputs=[
io.Image.Input("images"),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Image.Output(is_output_list=True),
],
)
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
@classmethod
def execute(cls, images, batch_size):
batch_size = batch_size[0]
output_list = []
@ -125,14 +137,17 @@ class ImageRebatch:
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
return io.NodeOutput(output_list)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
"RebatchImages": "Rebatch Images",
}
class RebatchExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LatentRebatch,
ImageRebatch,
]
async def comfy_entrypoint() -> RebatchExtension:
return RebatchExtension()