From 041dbd6a8a241eccb8a35eddc80b78176f42b7f0 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 7 Dec 2025 01:00:08 +0200 Subject: [PATCH] add nodes --- comfy/ldm/seedvr/vae.py | 10 +-- comfy_extras/nodes_seedvr.py | 116 +++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 comfy_extras/nodes_seedvr.py diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index eb74e9442..51c5b2578 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1051,10 +1051,9 @@ class VideoAutoencoderKL(nn.Module): out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock3D",), up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, + layers_per_block: int = 2, act_fn: str = "silu", - latent_channels: int = 4, + latent_channels: int = 16, norm_num_groups: int = 32, attention: bool = True, temporal_scale_num: int = 2, @@ -1062,12 +1061,13 @@ class VideoAutoencoderKL(nn.Module): gradient_checkpoint: bool = False, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "full", - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, *args, **kwargs, ): extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + block_out_channels = (128, 256, 512, 512) super().__init__() # pass init params to Encoder diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py new file mode 100644 index 000000000..60bd551dd --- /dev/null +++ b/comfy_extras/nodes_seedvr.py @@ -0,0 +1,116 @@ + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io, ui +import torch +import math +from einops import rearrange + +from torchvision.transforms import functional as TVF +from torchvision.transforms import Lambda, Normalize +from torchvision.transforms.functional import InterpolationMode + + +def area_resize(image, max_area): + + height, width = image.shape[-2:] + scale = math.sqrt(max_area / (height * width)) + + resized_height, resized_width = round(height * scale), round(width * scale) + + return TVF.resize( + image, + size=(resized_height, resized_width), + interpolation=InterpolationMode.BICUBIC, + ) + +def crop(image, factor): + height_factor, width_factor = factor + height, width = image.shape[-2:] + + cropped_height = height - (height % height_factor) + cropped_width = width - (width % width_factor) + + image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) + return image + +def cut_videos(videos): + t = videos.size(1) + if t == 1: + return videos + if t <= 4 : + padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 - ((t - 1) % (4)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4) == 0 + return videos + +class SeedVR2InputProcessing(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id = "SeedVR2InputProcessing", + category="image/video", + inputs = [ + io.Image.Input("images"), + io.Int.Input("resolution_height"), + io.Int.Input("resolution_width") + ], + outputs = [ + io.Image.Output("images") + ] + ) + + @classmethod + def execute(cls, images, resolution_height, resolution_width): + max_area = ((resolution_height * resolution_width)** 0.5) ** 2 + clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + normalize = Normalize(0.5, 0.5) + images = area_resize(images, max_area) + images = clip(images) + images = crop(images, (16, 16)) + images = normalize(images) + images = rearrange(images, "t c h w -> c t h w") + images = cut_videos(images) + return + +class SeedVR2Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Conditioning", + category="image/video", + inputs=[ + io.Conditioning.Input("text_positive_conditioning"), + io.Conditioning.Input("text_negative_conditioning"), + io.Conditioning.Input("vae_conditioning") + ], + outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")], + ) + + @classmethod + def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: + # TODO + pos_cond = text_positive_conditioning[0][0] + neg_cond = text_negative_conditioning[0][0] + + return io.NodeOutput() + +class SeedVRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SeedVR2Conditioning, + SeedVR2InputProcessing + ] + +async def comfy_entrypoint() -> SeedVRExtension: + return SeedVRExtension() \ No newline at end of file