mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 17:59:54 +08:00
better
This commit is contained in:
parent
19475fdafc
commit
bb272ea09f
@ -570,6 +570,17 @@ class WanModel(torch.nn.Module):
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# In-context reference streams (e.g. Bernini source video / ref images):
|
||||
# patch-embed each clean condition latent and append as extra tokens (their
|
||||
# rope, with per-stream source_id, was appended to `freqs` in _forward).
|
||||
# Inert when no context_latents are supplied.
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
main_len = x.shape[1]
|
||||
if context_latents is not None:
|
||||
for lat in context_latents:
|
||||
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
|
||||
x = torch.cat([x, cl], dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@ -599,6 +610,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if context_latents is not None:
|
||||
x = x[:, :main_len]
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
@ -606,7 +620,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@ -638,6 +652,16 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
# In-context reference conditioning (e.g. Bernini): a non-zero source_id
|
||||
# composes an extra rotation (over the full head_dim) into the spatial
|
||||
# rope so streams sharing the same spatial coords stay distinct. source_id
|
||||
# 0 is identity, so this is a no-op for all normal Wan usage.
|
||||
if source_id:
|
||||
d = self.dim // self.num_heads
|
||||
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
|
||||
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
|
||||
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
@ -661,6 +685,16 @@ class WanModel(torch.nn.Module):
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
# In-context reference streams: one rope block per stream, each with its
|
||||
# own source_id (1, 2, ...) so they stay distinct from the target (id 0).
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
|
||||
for i, lat in enumerate(context_latents):
|
||||
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
|
||||
kwargs = {**kwargs, "context_latents": context_latents}
|
||||
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@ -1739,126 +1773,3 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class BerniniWanModel(WanModel):
|
||||
"""Wan2.2-A14B fine-tune (ByteDance Bernini-R) with in-context conditioning.
|
||||
|
||||
Source video / reference image latents are patch-embedded with the same
|
||||
``patch_embedding`` as the noisy target and concatenated as extra tokens
|
||||
along the sequence. Each conditioning stream carries a ``source_id`` (target
|
||||
= 0, conditions = 1, 2, ...) realised as an extra multiplicative rotary
|
||||
factor composed into the spatial RoPE: spatial coordinates overlap across
|
||||
streams, only the source_id separates them. Self-attention is full over the
|
||||
concatenated sequence; the target tokens are sliced back out afterwards.
|
||||
|
||||
The condition latents arrive as kwargs (``bernini_video_latent``,
|
||||
``bernini_image_latents``) from ``WAN22_Bernini.extra_conds``.
|
||||
"""
|
||||
|
||||
def _source_id_freqs(self, freqs, source_id):
|
||||
# Compose an extra rotation (by source_id, over the full head_dim) into
|
||||
# the spatial rope. source_id == 0 -> identity (target unchanged).
|
||||
if source_id == 0:
|
||||
return freqs
|
||||
d = self.dim // self.num_heads
|
||||
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
|
||||
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
|
||||
return torch.einsum('...ij,...jk->...ik', freqs, id_rot)
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
|
||||
freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
return self._source_id_freqs(freqs, source_id)
|
||||
|
||||
def _bernini_conditions(self, kwargs):
|
||||
# Returns [(latent[B,C,T,H,W], source_id), ...] in concat order:
|
||||
# source video first (source_id 1), then each reference image (2, 3, ...).
|
||||
specs = []
|
||||
sid = 1
|
||||
video = kwargs.get("bernini_video_latent", None)
|
||||
if video is not None:
|
||||
specs.append((video, sid))
|
||||
sid += 1
|
||||
images = kwargs.get("bernini_image_latents", None)
|
||||
if images is not None:
|
||||
for i in range(images.shape[2]):
|
||||
specs.append((images[:, :, i:i + 1], sid))
|
||||
sid += 1
|
||||
return specs
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
specs = [(comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size), sid)
|
||||
for lat, sid in self._bernini_conditions(kwargs)]
|
||||
|
||||
# Target rope (source_id 0) first, then one block per condition stream.
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=0)
|
||||
for lat, sid in specs:
|
||||
cf = self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=sid)
|
||||
freqs = torch.cat([freqs, cf], dim=1)
|
||||
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, bernini_cond_specs=specs, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, bernini_cond_specs=None, **kwargs):
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
target_len = x.shape[1]
|
||||
|
||||
# in-context conditions: patch-embed and append (matching freqs order)
|
||||
if bernini_cond_specs:
|
||||
for lat, _ in bernini_cond_specs:
|
||||
cond = self.patch_embedding(lat.float().to(x.device)).to(x.dtype)
|
||||
x = torch.cat([x, cond.flatten(2).transpose(1, 2)], dim=1)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None:
|
||||
if self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.cat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# drop the appended condition tokens, keep the target
|
||||
if bernini_cond_specs:
|
||||
x = x[:, :target_len]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
@ -1516,6 +1516,13 @@ class WAN21(BaseModel):
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
# In-context reference conditioning (source video / reference images,
|
||||
# e.g. Bernini): a list of clean latents appended as extra token streams
|
||||
# with per-stream source_id rope. Inert when not supplied.
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -1708,32 +1715,6 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class WAN22_Bernini(WAN22):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.BerniniWanModel)
|
||||
self.image_to_video = image_to_video
|
||||
self.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents")
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
video = kwargs.get("bernini_video_latent", None)
|
||||
if video is not None:
|
||||
out["bernini_video_latent"] = comfy.conds.CONDRegular(self.process_latent_in(video))
|
||||
images = kwargs.get("bernini_image_latents", None)
|
||||
if images is not None:
|
||||
out["bernini_image_latents"] = comfy.conds.CONDRegular(self.process_latent_in(images))
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = super().extra_conds_shapes(**kwargs)
|
||||
video = kwargs.get("bernini_video_latent", None)
|
||||
if video is not None:
|
||||
out["bernini_video_latent"] = video.shape
|
||||
images = kwargs.get("bernini_image_latents", None)
|
||||
if images is not None:
|
||||
out["bernini_image_latents"] = images.shape
|
||||
return out
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
|
||||
@ -1,89 +1,86 @@
|
||||
import torch
|
||||
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.model_base
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
|
||||
|
||||
def _patch_bernini(model):
|
||||
"""Flip a loaded Wan2.2-A14B model into Bernini-R mode.
|
||||
|
||||
The Bernini checkpoint is architecturally identical to Wan2.2-A14B (no new
|
||||
params), so we just swap the forward (BerniniWanModel) and the conditioning
|
||||
plumbing (WAN22_Bernini) onto the already-loaded model. Idempotent.
|
||||
"""
|
||||
model.model.diffusion_model.__class__ = comfy.ldm.wan.model.BerniniWanModel
|
||||
model.model.__class__ = comfy.model_base.WAN22_Bernini
|
||||
model.model.memory_usage_factor_conds = ("bernini_video_latent", "bernini_image_latents")
|
||||
return model
|
||||
|
||||
|
||||
def _encode_frames(vae, image, width, height):
|
||||
image = comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
return vae.encode(image[:, :, :, :3])
|
||||
def _resize_long_edge(image, max_size, stride=16):
|
||||
"""Resize (preserve aspect) so the long edge <= max_size, snapped to `stride`."""
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = min(max_size / max(h, w), 1.0)
|
||||
nh = max(stride, round(h * scale / stride) * stride)
|
||||
nw = max(stride, round(w * scale / stride) * stride)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
class BerniniConditioning:
|
||||
"""Routes Bernini-R inputs and activates Bernini mode on the model(s).
|
||||
"""Bernini-R in-context conditioning for a Wan2.2-A14B model.
|
||||
|
||||
Attaches the VAE-encoded source video / reference images to BOTH the
|
||||
positive and negative conditioning so stock CFG keeps the conditions fixed
|
||||
and only varies the text -- giving Bernini's v2v / rv2v guidance form. For
|
||||
cfg=1.0 (distill LoRA) the same setup is a single forward with the full
|
||||
conditioning. t2v attaches nothing.
|
||||
Attaches the VAE-encoded source video / reference images to BOTH positive and
|
||||
negative conditioning as ``context_latents`` -- an ordered list of clean
|
||||
latent streams (source video first, then each reference image), which the Wan
|
||||
model appends as extra tokens with per-stream source_id rope. With stock CFG
|
||||
the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA)
|
||||
it's a single forward over the full conditioning.
|
||||
|
||||
The task is inferred from which inputs are connected -- no model input and no
|
||||
task selector needed; the model loads as a normal Wan2.2 checkpoint via the
|
||||
stock UNETLoader:
|
||||
(nothing) -> t2v
|
||||
source_video -> v2v
|
||||
source_video + ref images -> rv2v
|
||||
ref images only -> r2v (each kept at native aspect)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"task_type": (["t2v", "v2v", "rv2v"],),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {
|
||||
"model_low": ("MODEL",),
|
||||
"source_video": ("IMAGE",),
|
||||
"reference_images": ("IMAGE",),
|
||||
"ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("model", "model_low", "positive", "negative", "latent")
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def build(self, model, positive, negative, vae, task_type, width, height, length, batch_size,
|
||||
model_low=None, source_video=None, reference_images=None):
|
||||
model = _patch_bernini(model)
|
||||
if model_low is not None:
|
||||
model_low = _patch_bernini(model_low)
|
||||
|
||||
def build(self, positive, negative, vae, width, height, length, batch_size,
|
||||
source_video=None, reference_images=None, ref_max_size=848):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
values = {}
|
||||
if task_type in ("v2v", "rv2v") and source_video is not None:
|
||||
values["bernini_video_latent"] = _encode_frames(vae, source_video[:length], width, height)
|
||||
# Ordered list of condition streams: source video (source_id 1) first,
|
||||
# then each reference image (source_id 2, 3, ...). The model assigns the
|
||||
# source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by
|
||||
# which inputs are present.
|
||||
context = []
|
||||
if source_video is not None:
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
context.append(vae.encode(vid[:, :, :, :3]))
|
||||
|
||||
if task_type == "rv2v" and reference_images is not None:
|
||||
# each reference image is an independent single-frame stream (its own source_id)
|
||||
refs = [_encode_frames(vae, reference_images[i:i + 1], width, height) for i in range(reference_images.shape[0])]
|
||||
values["bernini_image_latents"] = torch.cat(refs, dim=2)
|
||||
if reference_images is not None:
|
||||
for i in range(reference_images.shape[0]):
|
||||
img = _resize_long_edge(reference_images[i:i + 1], ref_max_size) # native aspect per ref
|
||||
context.append(vae.encode(img[:, :, :, :3]))
|
||||
|
||||
if values:
|
||||
positive = node_helpers.conditioning_set_values(positive, values)
|
||||
negative = node_helpers.conditioning_set_values(negative, values)
|
||||
if context:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
|
||||
|
||||
return (model, model_low, positive, negative, {"samples": latent})
|
||||
return (positive, negative, {"samples": latent})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user