mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
convert nodes_rebatch.py to V3 schema (#9945)
This commit is contained in:
parent
c4a46e943c
commit
76eb1d72c3
@ -1,18 +1,25 @@
|
|||||||
|
from typing_extensions import override
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class LatentRebatch:
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class LatentRebatch(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "latents": ("LATENT",),
|
return io.Schema(
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
node_id="RebatchLatents",
|
||||||
}}
|
display_name="Rebatch Latents",
|
||||||
RETURN_TYPES = ("LATENT",)
|
category="latent/batch",
|
||||||
INPUT_IS_LIST = True
|
is_input_list=True,
|
||||||
OUTPUT_IS_LIST = (True, )
|
inputs=[
|
||||||
|
io.Latent.Input("latents"),
|
||||||
FUNCTION = "rebatch"
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
CATEGORY = "latent/batch"
|
outputs=[
|
||||||
|
io.Latent.Output(is_output_list=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_batch(latents, list_ind, offset):
|
def get_batch(latents, list_ind, offset):
|
||||||
@ -53,7 +60,8 @@ class LatentRebatch:
|
|||||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def rebatch(self, latents, batch_size):
|
@classmethod
|
||||||
|
def execute(cls, latents, batch_size):
|
||||||
batch_size = batch_size[0]
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
output_list = []
|
output_list = []
|
||||||
@ -63,24 +71,24 @@ class LatentRebatch:
|
|||||||
for i in range(len(latents)):
|
for i in range(len(latents)):
|
||||||
# fetch new entry of list
|
# fetch new entry of list
|
||||||
#samples, masks, indices = self.get_batch(latents, i)
|
#samples, masks, indices = self.get_batch(latents, i)
|
||||||
next_batch = self.get_batch(latents, i, processed)
|
next_batch = cls.get_batch(latents, i, processed)
|
||||||
processed += len(next_batch[2])
|
processed += len(next_batch[2])
|
||||||
# set to current if current is None
|
# set to current if current is None
|
||||||
if current_batch[0] is None:
|
if current_batch[0] is None:
|
||||||
current_batch = next_batch
|
current_batch = next_batch
|
||||||
# add previous to list if dimensions do not match
|
# add previous to list if dimensions do not match
|
||||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
current_batch = next_batch
|
current_batch = next_batch
|
||||||
# cat if everything checks out
|
# cat if everything checks out
|
||||||
else:
|
else:
|
||||||
current_batch = self.cat_batch(current_batch, next_batch)
|
current_batch = cls.cat_batch(current_batch, next_batch)
|
||||||
|
|
||||||
# add to list if dimensions gone above target batch size
|
# add to list if dimensions gone above target batch size
|
||||||
if current_batch[0].shape[0] > batch_size:
|
if current_batch[0].shape[0] > batch_size:
|
||||||
num = current_batch[0].shape[0] // batch_size
|
num = current_batch[0].shape[0] // batch_size
|
||||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
|
||||||
|
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||||
@ -89,7 +97,7 @@ class LatentRebatch:
|
|||||||
|
|
||||||
#add remainder
|
#add remainder
|
||||||
if current_batch[0] is not None:
|
if current_batch[0] is not None:
|
||||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
|
||||||
#get rid of empty masks
|
#get rid of empty masks
|
||||||
@ -97,23 +105,27 @@ class LatentRebatch:
|
|||||||
if s['noise_mask'].mean() == 1.0:
|
if s['noise_mask'].mean() == 1.0:
|
||||||
del s['noise_mask']
|
del s['noise_mask']
|
||||||
|
|
||||||
return (output_list,)
|
return io.NodeOutput(output_list)
|
||||||
|
|
||||||
class ImageRebatch:
|
class ImageRebatch(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "images": ("IMAGE",),
|
return io.Schema(
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
node_id="RebatchImages",
|
||||||
}}
|
display_name="Rebatch Images",
|
||||||
RETURN_TYPES = ("IMAGE",)
|
category="image/batch",
|
||||||
INPUT_IS_LIST = True
|
is_input_list=True,
|
||||||
OUTPUT_IS_LIST = (True, )
|
inputs=[
|
||||||
|
io.Image.Input("images"),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(is_output_list=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "rebatch"
|
@classmethod
|
||||||
|
def execute(cls, images, batch_size):
|
||||||
CATEGORY = "image/batch"
|
|
||||||
|
|
||||||
def rebatch(self, images, batch_size):
|
|
||||||
batch_size = batch_size[0]
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
output_list = []
|
output_list = []
|
||||||
@ -125,14 +137,17 @@ class ImageRebatch:
|
|||||||
for i in range(0, len(all_images), batch_size):
|
for i in range(0, len(all_images), batch_size):
|
||||||
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
||||||
|
|
||||||
return (output_list,)
|
return io.NodeOutput(output_list)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"RebatchLatents": LatentRebatch,
|
|
||||||
"RebatchImages": ImageRebatch,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class RebatchExtension(ComfyExtension):
|
||||||
"RebatchLatents": "Rebatch Latents",
|
@override
|
||||||
"RebatchImages": "Rebatch Images",
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
LatentRebatch,
|
||||||
|
ImageRebatch,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> RebatchExtension:
|
||||||
|
return RebatchExtension()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user