# ComfyUI/comfy_extras/nodes_bernini.py
# 2026-06-01 19:18:33 +03:00
#
# 96 lines
# 3.8 KiB
# Python

import logging

import torch

import comfy.ldm.wan.model
import comfy.model_base
import comfy.model_management
import comfy.utils
import node_helpers
def _patch_bernini(model):
    """Switch an already-loaded Wan2.2-A14B MODEL over to Bernini-R behaviour.

    Bernini-R reuses the Wan2.2-A14B architecture with no extra parameters, so
    nothing is loaded here. Instead, the classes of the wrapped objects are
    swapped in place:

    * ``model.model.diffusion_model`` becomes ``BerniniWanModel`` (new forward),
    * ``model.model`` becomes ``WAN22_Bernini`` (conditioning plumbing),

    and the names of the conds used for memory estimation are recorded on
    ``model.model``. Calling this again simply re-assigns the same values.

    NOTE(review): ``model.model`` itself is mutated rather than a clone, so any
    other holder of the same loaded model object sees the change as well.

    Returns the very same ``model`` object that was passed in.
    """
    base = model.model
    base.diffusion_model.__class__ = comfy.ldm.wan.model.BerniniWanModel
    base.__class__ = comfy.model_base.WAN22_Bernini
    base.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents")
    return model
def _encode_frames(vae, image, width, height):
    """Resize IMAGE frames to ``width`` x ``height`` and encode them with ``vae``.

    Only the first three channels of the last axis are kept, so any alpha
    channel is dropped. ``image`` is assumed to be channels-last, presumably
    ComfyUI's (frames, H, W, C) IMAGE layout (TODO confirm).
    ``comfy.utils.common_upscale`` is called on a channels-first view with
    bilinear filtering and a center crop, hence the movedim round-trip.
    """
    rgb = image[:, :, :, :3]
    channels_first = rgb.movedim(-1, 1)
    resized = comfy.utils.common_upscale(channels_first, width, height, "bilinear", "center")
    # Back to channels-last. The tensor still has at most 3 channels here, so no
    # second channel slice is needed before encoding.
    frames = resized.movedim(1, -1)
    return vae.encode(frames)
class BerniniConditioning:
    """Routes Bernini-R inputs and activates Bernini mode on the model(s).

    Attaches the VAE-encoded source video / reference images to BOTH the
    positive and negative conditioning so stock CFG keeps the conditions fixed
    and only varies the text -- giving Bernini's v2v / rv2v guidance form. For
    cfg=1.0 (distill LoRA) the same setup is a single forward with the full
    conditioning. t2v attaches nothing.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's input sockets and widgets for ComfyUI."""
        return {
            "required": {
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "vae": ("VAE",),
                "task_type": (["t2v", "v2v", "rv2v"],),
                "width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}),
                "height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}),
                "length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
            },
            "optional": {
                "model_low": ("MODEL",),
                "source_video": ("IMAGE",),
                "reference_images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("MODEL", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("model", "model_low", "positive", "negative", "latent")
    FUNCTION = "build"
    CATEGORY = "conditioning/video_models"

    def build(self, model, positive, negative, vae, task_type, width, height, length, batch_size,
              model_low=None, source_video=None, reference_images=None):
        """Patch the model(s) into Bernini mode and build conditioning plus an empty latent.

        Args:
            model: MODEL to switch to Bernini-R (patched in place by ``_patch_bernini``).
            positive, negative: CONDITIONING lists. Both receive the same Bernini values.
            vae: VAE used to encode the source video and the reference images.
            task_type: ``"t2v"``, ``"v2v"`` or ``"rv2v"``.
            width, height: target resolution in pixels. Inputs are resized to it.
            length: number of output frames. ``source_video`` is truncated to this.
            batch_size: batch dimension of the returned empty latent.
            model_low: optional second MODEL, patched the same way when given.
            source_video: IMAGE frames, used for ``v2v`` and ``rv2v``.
            reference_images: IMAGE batch, used only for ``rv2v``. Each image is
                encoded as its own single-frame stream.

        Returns:
            ``(model, model_low, positive, negative, {"samples": latent})``, where
            ``model_low`` is ``None`` if it was not connected.
        """
        model = _patch_bernini(model)
        if model_low is not None:
            model_low = _patch_bernini(model_low)

        # 16 latent channels, ((length - 1) // 4) + 1 latent frames and an 8x
        # spatial downscale -- the latent geometry this node targets.
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
                             device=comfy.model_management.intermediate_device())

        values = {}
        if task_type in ("v2v", "rv2v") and source_video is not None:
            values["bernini_video_latent"] = _encode_frames(vae, source_video[:length], width, height)
        # Skip an empty batch explicitly: torch.cat on an empty list would raise.
        if task_type == "rv2v" and reference_images is not None and reference_images.shape[0] > 0:
            # each reference image is an independent single-frame stream (its own source_id)
            refs = [_encode_frames(vae, frame, width, height)
                    for frame in reference_images.split(1, dim=0)]
            values["bernini_image_latents"] = torch.cat(refs, dim=2)

        if values:
            positive = node_helpers.conditioning_set_values(positive, values)
            negative = node_helpers.conditioning_set_values(negative, values)
        elif task_type != "t2v":
            # Previously this fell back to text-only conditioning without any signal.
            logging.warning(
                "BerniniConditioning: task_type '%s' selected but no source_video/reference_images "
                "were provided; no Bernini conditioning attached (behaves like t2v).",
                task_type,
            )

        return (model, model_low, positive, negative, {"samples": latent})
# One row per node this module registers: (node id, implementing class, UI title).
_BERNINI_NODES = (
    ("BerniniConditioning", BerniniConditioning, "Bernini Conditioning"),
)

# Registries read by ComfyUI's node loader, both keyed by node id.
NODE_CLASS_MAPPINGS = {node_id: node_cls for node_id, node_cls, _ in _BERNINI_NODES}
NODE_DISPLAY_NAME_MAPPINGS = {node_id: title for node_id, _, title in _BERNINI_NODES}