ComfyUI/comfy_extras/nodes_seedvr.py
Yousef Rafat 041dbd6a8a add nodes
2025-12-07 01:00:08 +02:00

116 lines
3.6 KiB
Python

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io, ui
import torch
import math
from einops import rearrange
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode
def area_resize(image, max_area):
    """Resize ``image`` so its spatial area is approximately ``max_area`` pixels.

    The last two dimensions are treated as (height, width), which is how
    ``TVF.resize`` interprets tensor inputs. Both sides are scaled by the same
    factor ``sqrt(max_area / (height * width))``, so the aspect ratio is kept
    up to rounding. Resampling is bicubic.

    Args:
        image: tensor whose last two dimensions are (height, width).
        max_area: target number of pixels (height * width) of the output.

    Returns:
        The resized tensor. Each output side is at least 1 pixel, so a very
        small ``max_area`` cannot round a side down to an invalid size of 0.

    Raises:
        ValueError: if the input has zero spatial area or ``max_area`` is not
            positive.
    """
    height, width = image.shape[-2:]
    area = height * width
    if area <= 0:
        raise ValueError(
            f"cannot resize an image with zero area (shape {tuple(image.shape)})"
        )
    if max_area <= 0:
        raise ValueError(f"max_area must be positive, got {max_area}")
    scale = math.sqrt(max_area / area)
    # Clamp to 1 so rounding can never request an empty output size.
    resized_height = max(1, round(height * scale))
    resized_width = max(1, round(width * scale))
    return TVF.resize(
        image,
        size=(resized_height, resized_width),
        interpolation=InterpolationMode.BICUBIC,
    )
def crop(image, factor):
    """Center-crop ``image`` so its height and width are multiples of ``factor``.

    Args:
        image: tensor whose last two dimensions are (height, width).
        factor: ``(height_multiple, width_multiple)`` pair.

    Returns:
        ``image`` center-cropped (via ``TVF.center_crop``) to the largest
        height and width that are multiples of the respective factors. For
        positive factors, these do not exceed the input size.
    """
    h_multiple, w_multiple = factor
    *_, in_h, in_w = image.shape
    # floor-division then multiply == "drop the remainder"
    target_size = (in_h // h_multiple * h_multiple, in_w // w_multiple * w_multiple)
    return TVF.center_crop(img=image, output_size=target_size)
def cut_videos(videos, temporal_factor=4):
    """Pad a clip in time so its frame count has the form ``k * temporal_factor + 1``.

    Frames are indexed along dim 1. In this module the tensor has been
    rearranged to ``(c, t, h, w)`` before this is called. Missing frames are
    filled by repeating the last frame. A clip that already satisfies the
    constraint, including a single frame, is returned unchanged (same object).

    Args:
        videos: tensor whose dim 1 is the frame axis.
        temporal_factor: frame-count granularity. It defaults to 4, the value
            that was previously hard-coded.

    Returns:
        The (possibly padded) tensor.

    Raises:
        ValueError: if ``temporal_factor < 1`` or ``videos`` has no frames.
    """
    if temporal_factor < 1:
        raise ValueError(f"temporal_factor must be >= 1, got {temporal_factor}")
    num_frames = videos.size(1)
    if num_frames == 0:
        raise ValueError("cut_videos needs at least one frame")
    # Number of frames needed to reach the next count of the form
    # k * temporal_factor + 1. This single formula also covers short clips:
    # 2..temporal_factor frames pad up to temporal_factor + 1.
    pad = (-(num_frames - 1)) % temporal_factor
    if pad == 0:
        return videos
    last_frame = videos[:, -1:]
    videos = torch.cat([videos] + [last_frame] * pad, dim=1)
    assert (videos.size(1) - 1) % temporal_factor == 0
    return videos
class SeedVR2InputProcessing(io.ComfyNode):
    """Turn a ComfyUI image batch into a normalized SeedVR2-style video tensor."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SeedVR2InputProcessing",
            category="image/video",
            inputs=[
                io.Image.Input("images"),
                io.Int.Input("resolution_height"),
                io.Int.Input("resolution_width"),
            ],
            outputs=[
                io.Image.Output("images"),
            ],
        )

    @classmethod
    def execute(cls, images, resolution_height, resolution_width) -> io.NodeOutput:
        """Resize, crop, normalize and temporally pad the input frames.

        Steps:
        1. Resize the frames to about ``resolution_height * resolution_width``
           pixels, keeping the aspect ratio.
        2. Clamp to [0, 1].
        3. Center-crop to multiples of 16.
        4. Map to [-1, 1].
        5. Rearrange to ``(c, t, h, w)``.
        6. Pad the frame count to ``4k + 1`` by repeating the last frame.

        Raises:
            ValueError: if either resolution is not positive.
        """
        if resolution_height <= 0 or resolution_width <= 0:
            raise ValueError(
                f"resolution must be positive, got {resolution_height}x{resolution_width}"
            )
        # ComfyUI IMAGE tensors are channels-last (frames, height, width,
        # channels), per the ComfyUI tensor docs. The resize/crop helpers
        # act on the last two dims, and the rearrange below expects
        # "t c h w", so move channels in front of the spatial dims.
        images = images.movedim(-1, 1)
        # Pixel budget for area_resize, which preserves the aspect ratio.
        max_area = resolution_height * resolution_width
        images = area_resize(images, max_area)
        # Bicubic resampling can overshoot the input range; clamp it back.
        images = torch.clamp(images, 0.0, 1.0)
        images = crop(images, (16, 16))
        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        images = Normalize(0.5, 0.5)(images)
        images = rearrange(images, "t c h w -> c t h w")
        images = cut_videos(images)
        # NOTE(review): the output is (c, t, h, w) in [-1, 1]. That is not
        # the channels-last [0, 1] layout other IMAGE consumers expect, so
        # confirm the intended downstream node.
        return io.NodeOutput(images)
class SeedVR2Conditioning(io.ComfyNode):
    """Combine text and VAE conditioning for SeedVR2 (work in progress).

    NOTE(review): ``execute`` is still an unfinished stub. It reads the first
    element of the first entry of each text conditioning input, ignores
    ``vae_conditioning``, and returns an empty ``io.NodeOutput``. The schema,
    however, declares two outputs ("positive" and "negative").
    """

    @classmethod
    def define_schema(cls):
        conditioning_inputs = [
            io.Conditioning.Input(input_name)
            for input_name in (
                "text_positive_conditioning",
                "text_negative_conditioning",
                "vae_conditioning",
            )
        ]
        conditioning_outputs = [
            io.Conditioning.Output("positive"),
            io.Conditioning.Output("negative"),
        ]
        return io.Schema(
            node_id="SeedVR2Conditioning",
            category="image/video",
            inputs=conditioning_inputs,
            outputs=conditioning_outputs,
        )

    @classmethod
    def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
        # TODO: build the real positive/negative outputs. The extracted
        # values below are not used yet.
        first_positive = text_positive_conditioning[0][0]
        first_negative = text_negative_conditioning[0][0]
        return io.NodeOutput()
class SeedVRExtension(ComfyExtension):
    """Extension object exposing the SeedVR2 node classes of this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension, in a fixed order."""
        provided_nodes: list[type[io.ComfyNode]] = [
            SeedVR2Conditioning,
            SeedVR2InputProcessing,
        ]
        return provided_nodes
async def comfy_entrypoint() -> SeedVRExtension:
    """Module entry hook: build and return a fresh ``SeedVRExtension``."""
    extension = SeedVRExtension()
    return extension