From 19475fdafc51804ea542bf5fb1dbd4e17664ffe4 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 1 Jun 2026 19:18:33 +0300
Subject: [PATCH] Initial commit

---
Reviewer notes (text between "---" and the diffstat is not part of the commit):
- This patch was re-flowed from a newline-collapsed copy. Blank context lines
  and the leading whitespace of context lines are reconstructed; check with
  `git apply --check` (or apply with --ignore-whitespace) against the base.
- The nesting of `context_img_len = clip_fea.shape[-2]` in forward_orig was
  not recoverable from the collapsed copy; it is placed under
  `if clip_fea is not None:` -- confirm against the author's tree.
- Blob ids on the "index" lines and the commit id on the "From" line belong
  to the original commit and do not match the edited content.

 comfy/ldm/wan/model.py        | 147 ++++++++++++++++++++++++++++++++-
 comfy/model_base.py           |  32 +++++++
 comfy_extras/nodes_bernini.py | 110 +++++++++++++++++++++++++
 nodes.py                      |   1 +
 4 files changed, 289 insertions(+), 1 deletion(-)
 create mode 100644 comfy_extras/nodes_bernini.py

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index 70dfe7b16..070a00b39 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -8,7 +8,7 @@ from einops import rearrange
 
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.flux.math import apply_rope1, rope
 import comfy.ldm.common_dit
 import comfy.model_management
 import comfy.patcher_extension
@@ -1739,3 +1739,148 @@ class SCAILWanModel(WanModel):
 
         freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
         return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
+
+
+class BerniniWanModel(WanModel):
+    """Wan2.2-A14B fine-tune (ByteDance Bernini-R) with in-context conditioning.
+
+    Source video / reference image latents are patch-embedded with the same
+    ``patch_embedding`` as the noisy target and concatenated as extra tokens
+    along the sequence. Each conditioning stream carries a ``source_id`` (target
+    = 0, conditions = 1, 2, ...) realised as an extra multiplicative rotary
+    factor composed into the spatial RoPE: spatial coordinates overlap across
+    streams, only the source_id separates them. Self-attention is full over the
+    concatenated sequence; the target tokens are sliced back out afterwards.
+
+    The condition latents arrive as kwargs (``bernini_video_latent``,
+    ``bernini_image_latents``) from ``WAN22_Bernini.extra_conds``.
+    """
+
+    def _source_id_freqs(self, freqs, source_id):
+        """Compose an extra rotation for ``source_id`` into ``freqs``.
+
+        The rotation comes from ``rope`` evaluated at the single position
+        ``source_id`` over the full head dim (``dim // num_heads``) and is
+        applied as a batched matrix product over the trailing two dims.
+        ``source_id == 0`` returns ``freqs`` as-is (target unchanged).
+        """
+        if source_id == 0:
+            return freqs
+        d = self.dim // self.num_heads
+        pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
+        # leading singleton dims let the single rotation broadcast against freqs
+        id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
+        return torch.einsum('...ij,...jk->...ik', freqs, id_rot)
+
+    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
+        """Return the parent's RoPE table with the ``source_id`` rotation applied."""
+        freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
+        return self._source_id_freqs(freqs, source_id)
+
+    def _bernini_conditions(self, kwargs):
+        """Return ``[(latent[B,C,T,H,W], source_id), ...]`` in concat order.
+
+        The source video comes first (source_id 1), then each reference image
+        as its own single-frame stream (2, 3, ...). Absent inputs are skipped,
+        so ids always start at 1 and stay consecutive.
+        """
+        specs = []
+        sid = 1
+        video = kwargs.get("bernini_video_latent", None)
+        if video is not None:
+            specs.append((video, sid))
+            sid += 1
+        images = kwargs.get("bernini_image_latents", None)
+        if images is not None:
+            # one stream per entry along dim 2, each kept as a length-1 axis
+            for frame in images.split(1, dim=2):
+                specs.append((frame, sid))
+                sid += 1
+        return specs
+
+    def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+        """Build the per-stream RoPE table, run ``forward_orig`` and crop the result to ``t, h, w``."""
+        bs, c, t, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+        t_len = t
+        if time_dim_concat is not None:
+            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
+            x = torch.cat([x, time_dim_concat], dim=2)
+            t_len = x.shape[2]
+
+        specs = [(comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size), sid)
+                 for lat, sid in self._bernini_conditions(kwargs)]
+
+        # Target rope (source_id 0) first, then one block per condition stream.
+        # Blocks are gathered and joined once rather than re-copied per stream.
+        freq_blocks = [self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=0)]
+        for lat, sid in specs:
+            freq_blocks.append(self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=sid))
+        freqs = torch.cat(freq_blocks, dim=1) if specs else freq_blocks[0]
+
+        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, bernini_cond_specs=specs, **kwargs)[:, :, :t, :h, :w]
+
+    def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, bernini_cond_specs=None, **kwargs):
+        """Run the transformer over target + condition tokens; return the target.
+
+        ``bernini_cond_specs`` is the ``[(latent, source_id), ...]`` list from
+        ``_forward``; conditions are appended in that order, which is the order
+        of the blocks already concatenated into ``freqs``.
+        """
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
+        x = x.flatten(2).transpose(1, 2)
+        target_len = x.shape[1]
+
+        # in-context conditions: patch-embed and append (matching freqs order)
+        if bernini_cond_specs:
+            tokens = [x]
+            for lat, _ in bernini_cond_specs:
+                cond = self.patch_embedding(lat.float().to(x.device)).to(x.dtype)
+                tokens.append(cond.flatten(2).transpose(1, 2))
+            x = torch.cat(tokens, dim=1)
+
+        # time embeddings
+        e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
+        e = e.reshape(t.shape[0], -1, e.shape[-1])
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+        # context
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea)
+                context = torch.cat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
+        for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+        # head
+        x = self.head(x, e)
+
+        # drop the appended condition tokens, keep the target
+        if bernini_cond_specs:
+            x = x[:, :target_len]
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 205178911..9afb80ff9 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1708,6 +1708,38 @@ class WAN22(WAN21):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image
 
+class WAN22_Bernini(WAN22):
+    """WAN22 variant that runs ``BerniniWanModel`` and forwards the Bernini condition latents."""
+
+    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+        # super(WAN21, self) starts the lookup after WAN21 in the MRO, so the
+        # WAN22 / WAN21 initialisers are not run and unet_model is chosen here.
+        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.BerniniWanModel)
+        self.image_to_video = image_to_video
+        self.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents")
+
+    def extra_conds(self, **kwargs):
+        """Add the Bernini latents (passed through ``process_latent_in``) to the parent's conds."""
+        out = super().extra_conds(**kwargs)
+        video = kwargs.get("bernini_video_latent", None)
+        if video is not None:
+            out["bernini_video_latent"] = comfy.conds.CONDRegular(self.process_latent_in(video))
+        images = kwargs.get("bernini_image_latents", None)
+        if images is not None:
+            out["bernini_image_latents"] = comfy.conds.CONDRegular(self.process_latent_in(images))
+        return out
+
+    def extra_conds_shapes(self, **kwargs):
+        """Report the shapes of the Bernini latents alongside the parent's cond shapes."""
+        out = super().extra_conds_shapes(**kwargs)
+        video = kwargs.get("bernini_video_latent", None)
+        if video is not None:
+            out["bernini_video_latent"] = video.shape
+        images = kwargs.get("bernini_image_latents", None)
+        if images is not None:
+            out["bernini_image_latents"] = images.shape
+        return out
+
 class WAN21_FlowRVS(WAN21):
     def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
         model_config.unet_config["model_type"] = "t2v"
diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py
new file mode 100644
index 000000000..4bfb05c43
--- /dev/null
+++ b/comfy_extras/nodes_bernini.py
@@ -0,0 +1,110 @@
+import torch
+
+import comfy.ldm.wan.model
+import comfy.model_base
+import comfy.model_management
+import comfy.utils
+import node_helpers
+
+
+def _patch_bernini(model):
+    """Flip a loaded Wan2.2-A14B model into Bernini-R mode.
+
+    The Bernini checkpoint is architecturally identical to Wan2.2-A14B (no new
+    params), so we just swap the forward (BerniniWanModel) and the conditioning
+    plumbing (WAN22_Bernini) onto the already-loaded model. Idempotent.
+
+    The swap mutates ``model.model`` in place and returns the same ``model``
+    object. Raises ValueError when the diffusion model is not a ``WanModel``,
+    since the swapped-in methods rely on that class.
+    """
+    diffusion_model = model.model.diffusion_model
+    if not isinstance(diffusion_model, comfy.ldm.wan.model.WanModel):
+        raise ValueError(f"Bernini mode needs a Wan diffusion model, got {type(diffusion_model).__name__}")
+    diffusion_model.__class__ = comfy.ldm.wan.model.BerniniWanModel
+    model.model.__class__ = comfy.model_base.WAN22_Bernini
+    model.model.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents")
+    return model
+
+
+def _encode_frames(vae, image, width, height):
+    """Resize ``image`` to ``width`` x ``height`` (first 3 channels only) and VAE-encode it."""
+    # channels are cut to 3 before the resize, so the result needs no second slice
+    image = comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+    return vae.encode(image)
+
+
+class BerniniConditioning:
+    """Routes Bernini-R inputs and activates Bernini mode on the model(s).
+
+    Attaches the VAE-encoded source video / reference images to BOTH the
+    positive and negative conditioning so stock CFG keeps the conditions fixed
+    and only varies the text -- giving Bernini's v2v / rv2v guidance form. For
+    cfg=1.0 (distill LoRA) the same setup is a single forward with the full
+    conditioning. t2v attaches nothing.
+    """
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "positive": ("CONDITIONING",),
+                "negative": ("CONDITIONING",),
+                "vae": ("VAE",),
+                "task_type": (["t2v", "v2v", "rv2v"],),
+                "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}),
+                "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}),
+                "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+            },
+            "optional": {
+                "model_low": ("MODEL",),
+                "source_video": ("IMAGE",),
+                "reference_images": ("IMAGE",),
+            },
+        }
+
+    RETURN_TYPES = ("MODEL", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("model", "model_low", "positive", "negative", "latent")
+    FUNCTION = "build"
+    CATEGORY = "conditioning/video_models"
+
+    def build(self, model, positive, negative, vae, task_type, width, height, length, batch_size,
+              model_low=None, source_video=None, reference_images=None):
+        """Patch the model(s), create the empty target latent and attach the encoded conditions.
+
+        ``source_video`` is used for ``v2v`` / ``rv2v`` (cut to ``length`` frames),
+        ``reference_images`` only for ``rv2v``; inputs not matching ``task_type``
+        are ignored.
+        """
+        model = _patch_bernini(model)
+        if model_low is not None:
+            model_low = _patch_bernini(model_low)
+
+        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
+                             device=comfy.model_management.intermediate_device())
+
+        values = {}
+        if task_type in ("v2v", "rv2v") and source_video is not None:
+            values["bernini_video_latent"] = _encode_frames(vae, source_video[:length], width, height)
+
+        if task_type == "rv2v" and reference_images is not None:
+            # each reference image is an independent single-frame stream (its own source_id)
+            refs = [_encode_frames(vae, reference_images[i:i + 1], width, height) for i in range(reference_images.shape[0])]
+            values["bernini_image_latents"] = torch.cat(refs, dim=2)
+
+        if values:
+            positive = node_helpers.conditioning_set_values(positive, values)
+            negative = node_helpers.conditioning_set_values(negative, values)
+
+        return (model, model_low, positive, negative, {"samples": latent})
+
+
+NODE_CLASS_MAPPINGS = {
+    "BerniniConditioning": BerniniConditioning,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "BerniniConditioning": "Bernini Conditioning",
+}
diff --git a/nodes.py b/nodes.py
index 5678bc22d..61f4071ab 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2403,6 +2403,7 @@ async def init_builtin_extra_nodes():
         "nodes_video.py",
         "nodes_lumina2.py",
         "nodes_wan.py",
+        "nodes_bernini.py",
         "nodes_lotus.py",
         "nodes_hunyuan3d.py",
         "nodes_primitive.py",