This commit is contained in:
kijai 2026-06-01 21:44:57 +03:00
parent 886b2e5102
commit 46ba987361
3 changed files with 51 additions and 54 deletions

View File

@@ -570,10 +570,7 @@ class WanModel(torch.nn.Module):
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1) x = torch.concat((full_ref, x), dim=1)
# In-context reference streams (e.g. Bernini source video / ref images): # In-context reference streams (Bernini)
# patch-embed each clean condition latent and append as extra tokens (their
# rope, with per-stream source_id, was appended to `freqs` in _forward).
# Inert when no context_latents are supplied.
context_latents = kwargs.get("context_latents", None) context_latents = kwargs.get("context_latents", None)
main_len = x.shape[1] main_len = x.shape[1]
if context_latents is not None: if context_latents is not None:

View File

@@ -1517,9 +1517,7 @@ class WAN21(BaseModel):
if reference_latents is not None: if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
# In-context reference conditioning (source video / reference images, # In-context reference conditioning (Bernini)
# e.g. Bernini): a list of clean latents appended as extra token streams
# with per-stream source_id rope. Inert when not supplied.
context_latents = kwargs.get("context_latents", None) context_latents = kwargs.get("context_latents", None)
if context_latents is not None: if context_latents is not None:
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])

View File

@@ -1,8 +1,10 @@
import torch import torch
from typing_extensions import override
import comfy.model_management import comfy.model_management
import comfy.utils import comfy.utils
import node_helpers import node_helpers
from comfy_api.latest import ComfyExtension, io
def _resize_long_edge(image, max_size, stride=16): def _resize_long_edge(image, max_size, stride=16):
@@ -11,22 +13,17 @@ def _resize_long_edge(image, max_size, stride=16):
scale = min(max_size / max(h, w), 1.0) scale = min(max_size / max(h, w), 1.0)
nh = max(stride, round(h * scale / stride) * stride) nh = max(stride, round(h * scale / stride) * stride)
nw = max(stride, round(w * scale / stride) * stride) nw = max(stride, round(w * scale / stride) * stride)
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1) return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
class BerniniConditioning: class BerniniConditioning(io.ComfyNode):
"""Bernini-R in-context conditioning for a Wan2.2-A14B model. """Bernini in-context conditioning for a Wan2.2-A14B model.
Attaches the VAE-encoded source video / reference images to BOTH positive and Attaches the VAE-encoded source video / reference images to the conditioning
negative conditioning as ``context_latents`` -- an ordered list of clean an ordered list of clean latents (source video first, then each reference image),
latent streams (source video first, then each reference image), which the Wan which the Wan model appends as extra tokens with per-stream source_id rope.
model appends as extra tokens with per-stream source_id rope. With stock CFG
the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA)
it's a single forward over the full conditioning.
The task is inferred from which inputs are connected -- no model input and no The task is inferred from which inputs are connected:
task selector needed; the model loads as a normal Wan2.2 checkpoint via the
stock UNETLoader:
(nothing) -> t2v (nothing) -> t2v
source_video -> v2v source_video -> v2v
source_video + ref images -> rv2v source_video + ref images -> rv2v
@@ -34,41 +31,43 @@ class BerniniConditioning:
""" """
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="BerniniConditioning",
"positive": ("CONDITIONING",), display_name="Bernini Conditioning",
"negative": ("CONDITIONING",), category="conditioning/video_models",
"vae": ("VAE",), description="Conditioning node for Bernini in-context video/image conditioning. Attach source video and/or reference images to the positive/negative conditioning, "
"width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}), "which the Wan model will append as extra tokens with per-stream source_id rope.",
"height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}), inputs=[
"length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}), io.Conditioning.Input("positive"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Conditioning.Input("negative"),
}, io.Vae.Input("vae"),
"optional": { io.Int.Input("width", default=832, min=16, max=8192, step=16),
"source_video": ("IMAGE",), io.Int.Input("height", default=480, min=16, max=8192, step=16),
"reference_images": ("IMAGE",), io.Int.Input("length", default=81, min=1, max=8192, step=4),
"ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}), io.Int.Input("batch_size", default=1, min=1, max=4096),
}, io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (original task v2v or rv2v). Resized to width/height and trimmed to length."),
} io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."),
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, vae, width, height, length, batch_size,
FUNCTION = "build" source_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
def build(self, positive, negative, vae, width, height, length, batch_size,
source_video=None, reference_images=None, ref_max_size=848):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device()) device=comfy.model_management.intermediate_device())
# Ordered list of condition streams: source video (source_id 1) first, # Ordered list of condition streams: source video (source_id 1) first,
# then each reference image (source_id 2, 3, ...). The model assigns the # then each reference image (source_id 2, 3, ...), the model assigns the source_id from list order.
# source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by
# which inputs are present.
context = [] context = []
if source_video is not None: if source_video is not None:
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
context.append(vae.encode(vid[:, :, :, :3])) context.append(vae.encode(vid[:, :, :, :3]))
if reference_images is not None: if reference_images is not None:
@@ -80,13 +79,16 @@ class BerniniConditioning:
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
return (positive, negative, {"samples": latent}) return io.NodeOutput(positive, negative, {"samples": latent})
NODE_CLASS_MAPPINGS = { class BerniniExtension(ComfyExtension):
"BerniniConditioning": BerniniConditioning, @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
BerniniConditioning,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"BerniniConditioning": "Bernini Conditioning", async def comfy_entrypoint() -> BerniniExtension:
} return BerniniExtension()