convert nodes_rebatch.py to V3 schema (#9945)

This commit is contained in:
Alexander Piskun 2025-09-27 00:10:49 +03:00 committed by GitHub
parent c4a46e943c
commit 76eb1d72c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,18 +1,25 @@
from typing_extensions import override
import torch import torch
class LatentRebatch: from comfy_api.latest import ComfyExtension, io
class LatentRebatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "latents": ("LATENT",), return io.Schema(
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RebatchLatents",
}} display_name="Rebatch Latents",
RETURN_TYPES = ("LATENT",) category="latent/batch",
INPUT_IS_LIST = True is_input_list=True,
OUTPUT_IS_LIST = (True, ) inputs=[
io.Latent.Input("latents"),
FUNCTION = "rebatch" io.Int.Input("batch_size", default=1, min=1, max=4096),
],
CATEGORY = "latent/batch" outputs=[
io.Latent.Output(is_output_list=True),
],
)
@staticmethod @staticmethod
def get_batch(latents, list_ind, offset): def get_batch(latents, list_ind, offset):
@ -53,7 +60,8 @@ class LatentRebatch:
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result return result
def rebatch(self, latents, batch_size): @classmethod
def execute(cls, latents, batch_size):
batch_size = batch_size[0] batch_size = batch_size[0]
output_list = [] output_list = []
@ -63,24 +71,24 @@ class LatentRebatch:
for i in range(len(latents)): for i in range(len(latents)):
# fetch new entry of list # fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i) #samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed) next_batch = cls.get_batch(latents, i, processed)
processed += len(next_batch[2]) processed += len(next_batch[2])
# set to current if current is None # set to current if current is None
if current_batch[0] is None: if current_batch[0] is None:
current_batch = next_batch current_batch = next_batch
# add previous to list if dimensions do not match # add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size) sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch current_batch = next_batch
# cat if everything checks out # cat if everything checks out
else: else:
current_batch = self.cat_batch(current_batch, next_batch) current_batch = cls.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size # add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size: if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size) sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
for i in range(num): for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
@ -89,7 +97,7 @@ class LatentRebatch:
#add remainder #add remainder
if current_batch[0] is not None: if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size) sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks #get rid of empty masks
@ -97,23 +105,27 @@ class LatentRebatch:
if s['noise_mask'].mean() == 1.0: if s['noise_mask'].mean() == 1.0:
del s['noise_mask'] del s['noise_mask']
return (output_list,) return io.NodeOutput(output_list)
class ImageRebatch: class ImageRebatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "images": ("IMAGE",), return io.Schema(
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), node_id="RebatchImages",
}} display_name="Rebatch Images",
RETURN_TYPES = ("IMAGE",) category="image/batch",
INPUT_IS_LIST = True is_input_list=True,
OUTPUT_IS_LIST = (True, ) inputs=[
io.Image.Input("images"),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Image.Output(is_output_list=True),
],
)
FUNCTION = "rebatch" @classmethod
def execute(cls, images, batch_size):
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0] batch_size = batch_size[0]
output_list = [] output_list = []
@ -125,14 +137,17 @@ class ImageRebatch:
for i in range(0, len(all_images), batch_size): for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,) return io.NodeOutput(output_list)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = { class RebatchExtension(ComfyExtension):
"RebatchLatents": "Rebatch Latents", @override
"RebatchImages": "Rebatch Images", async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
LatentRebatch,
ImageRebatch,
]
async def comfy_entrypoint() -> RebatchExtension:
return RebatchExtension()