From 46ba987361aa656ee279d782a7d80b17c2dd2c4f Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:44:57 +0300 Subject: [PATCH] Cleanup --- comfy/ldm/wan/model.py | 5 +- comfy/model_base.py | 4 +- comfy_extras/nodes_bernini.py | 96 ++++++++++++++++++----------------- 3 files changed, 51 insertions(+), 54 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 15689a428..394b71d08 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -570,10 +570,7 @@ class WanModel(torch.nn.Module): full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) - # In-context reference streams (e.g. Bernini source video / ref images): - # patch-embed each clean condition latent and append as extra tokens (their - # rope, with per-stream source_id, was appended to `freqs` in _forward). - # Inert when no context_latents are supplied. + # In-context reference streams (Bernini) context_latents = kwargs.get("context_latents", None) main_len = x.shape[1] if context_latents is not None: diff --git a/comfy/model_base.py b/comfy/model_base.py index 60c9055fc..88155b9ae 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1517,9 +1517,7 @@ class WAN21(BaseModel): if reference_latents is not None: out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) - # In-context reference conditioning (source video / reference images, - # e.g. Bernini): a list of clean latents appended as extra token streams - # with per-stream source_id rope. Inert when not supplied. + # In-context reference conditioning (Bernini) context_latents = kwargs.get("context_latents", None) if context_latents is not None: out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py index 34777f2d6..aab4dbd1b 100644 --- a/comfy_extras/nodes_bernini.py +++ b/comfy_extras/nodes_bernini.py @@ -1,8 +1,10 @@ import torch +from typing_extensions import override import comfy.model_management import comfy.utils import node_helpers +from comfy_api.latest import ComfyExtension, io def _resize_long_edge(image, max_size, stride=16): @@ -11,22 +13,17 @@ def _resize_long_edge(image, max_size, stride=16): scale = min(max_size / max(h, w), 1.0) nh = max(stride, round(h * scale / stride) * stride) nw = max(stride, round(w * scale / stride) * stride) - return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1) + return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1) -class BerniniConditioning: - """Bernini-R in-context conditioning for a Wan2.2-A14B model. +class BerniniConditioning(io.ComfyNode): + """Bernini in-context conditioning for a Wan2.2-A14B model. - Attaches the VAE-encoded source video / reference images to BOTH positive and - negative conditioning as ``context_latents`` -- an ordered list of clean - latent streams (source video first, then each reference image), which the Wan - model appends as extra tokens with per-stream source_id rope. With stock CFG - the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA) - it's a single forward over the full conditioning. + Attaches the VAE-encoded source video / reference images to the conditioning + an ordered list of clean latents (source video first, then each reference image), + which the Wan model appends as extra tokens with per-stream source_id rope. - The task is inferred from which inputs are connected -- no model input and no - task selector needed; the model loads as a normal Wan2.2 checkpoint via the - stock UNETLoader: + The task is inferred from which inputs are connected: (nothing) -> t2v source_video -> v2v source_video + ref images -> rv2v @@ -34,41 +31,43 @@ class BerniniConditioning: """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "vae": ("VAE",), - "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": { - "source_video": ("IMAGE",), - "reference_images": ("IMAGE",), - "ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}), - }, - } + def define_schema(cls): + return io.Schema( + node_id="BerniniConditioning", + display_name="Bernini Conditioning", + category="conditioning/video_models", + description="Conditioning node for Bernini in-context video/image conditioning. Attach source video and/or reference images to the positive/negative conditioning, " + "which the Wan model will append as extra tokens with per-stream source_id rope.", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=8192, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (original task v2v or rv2v). Resized to width/height and trimmed to length."), + io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."), + io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "build" - CATEGORY = "conditioning/video_models" - - def build(self, positive, negative, vae, width, height, length, batch_size, - source_video=None, reference_images=None, ref_max_size=848): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, + source_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) # Ordered list of condition streams: source video (source_id 1) first, - # then each reference image (source_id 2, 3, ...). The model assigns the - # source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by - # which inputs are present. + # then each reference image (source_id 2, 3, ...), the model assigns the source_id from list order. context = [] if source_video is not None: - vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) context.append(vae.encode(vid[:, :, :, :3])) if reference_images is not None: @@ -80,13 +79,16 @@ class BerniniConditioning: positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) - return (positive, negative, {"samples": latent}) + return io.NodeOutput(positive, negative, {"samples": latent}) -NODE_CLASS_MAPPINGS = { - "BerniniConditioning": BerniniConditioning, -} +class BerniniExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + BerniniConditioning, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "BerniniConditioning": "Bernini Conditioning", -} + +async def comfy_entrypoint() -> BerniniExtension: + return BerniniExtension()