mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 17:59:54 +08:00
Cleanup
This commit is contained in:
parent
886b2e5102
commit
46ba987361
@ -570,10 +570,7 @@ class WanModel(torch.nn.Module):
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# In-context reference streams (e.g. Bernini source video / ref images):
|
||||
# patch-embed each clean condition latent and append as extra tokens (their
|
||||
# rope, with per-stream source_id, was appended to `freqs` in _forward).
|
||||
# Inert when no context_latents are supplied.
|
||||
# In-context reference streams (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
main_len = x.shape[1]
|
||||
if context_latents is not None:
|
||||
|
||||
@ -1517,9 +1517,7 @@ class WAN21(BaseModel):
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
# In-context reference conditioning (source video / reference images,
|
||||
# e.g. Bernini): a list of clean latents appended as extra token streams
|
||||
# with per-stream source_id rope. Inert when not supplied.
|
||||
# In-context reference conditioning (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def _resize_long_edge(image, max_size, stride=16):
|
||||
@ -11,22 +13,17 @@ def _resize_long_edge(image, max_size, stride=16):
|
||||
scale = min(max_size / max(h, w), 1.0)
|
||||
nh = max(stride, round(h * scale / stride) * stride)
|
||||
nw = max(stride, round(w * scale / stride) * stride)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "bilinear", "disabled").movedim(1, -1)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
class BerniniConditioning:
|
||||
"""Bernini-R in-context conditioning for a Wan2.2-A14B model.
|
||||
class BerniniConditioning(io.ComfyNode):
|
||||
"""Bernini in-context conditioning for a Wan2.2-A14B model.
|
||||
|
||||
Attaches the VAE-encoded source video / reference images to BOTH positive and
|
||||
negative conditioning as ``context_latents`` -- an ordered list of clean
|
||||
latent streams (source video first, then each reference image), which the Wan
|
||||
model appends as extra tokens with per-stream source_id rope. With stock CFG
|
||||
the conditions stay fixed and only the text varies; at cfg=1.0 (distill LoRA)
|
||||
it's a single forward over the full conditioning.
|
||||
Attaches the VAE-encoded source video / reference images to the conditioning
|
||||
an ordered list of clean latents (source video first, then each reference image),
|
||||
which the Wan model appends as extra tokens with per-stream source_id rope.
|
||||
|
||||
The task is inferred from which inputs are connected -- no model input and no
|
||||
task selector needed; the model loads as a normal Wan2.2 checkpoint via the
|
||||
stock UNETLoader:
|
||||
The task is inferred from which inputs are connected:
|
||||
(nothing) -> t2v
|
||||
source_video -> v2v
|
||||
source_video + ref images -> rv2v
|
||||
@ -34,41 +31,43 @@ class BerniniConditioning:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 832, "min": 16, "max": 8192, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": 8192, "step": 16}),
|
||||
"length": ("INT", {"default": 81, "min": 1, "max": 8192, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {
|
||||
"source_video": ("IMAGE",),
|
||||
"reference_images": ("IMAGE",),
|
||||
"ref_max_size": ("INT", {"default": 848, "min": 16, "max": 8192, "step": 16}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BerniniConditioning",
|
||||
display_name="Bernini Conditioning",
|
||||
category="conditioning/video_models",
|
||||
description="Conditioning node for Bernini in-context video/image conditioning. Attach source video and/or reference images to the positive/negative conditioning, "
|
||||
"which the Wan model will append as extra tokens with per-stream source_id rope.",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=8192, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("source_video", optional=True, tooltip="Source video to edit/restyle (original task v2v or rv2v). Resized to width/height and trimmed to length."),
|
||||
io.Image.Input("reference_images", optional=True, tooltip="Reference image(s) injected as in-context tokens (task r2v or rv2v). Each is kept at its native aspect ratio, long edge capped at ref_max_size."),
|
||||
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
FUNCTION = "build"
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def build(self, positive, negative, vae, width, height, length, batch_size,
|
||||
source_video=None, reference_images=None, ref_max_size=848):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size,
|
||||
source_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
# Ordered list of condition streams: source video (source_id 1) first,
|
||||
# then each reference image (source_id 2, 3, ...). The model assigns the
|
||||
# source_id from list order. The task (t2v/v2v/rv2v/r2v) is implied by
|
||||
# which inputs are present.
|
||||
# then each reference image (source_id 2, 3, ...), the model assigns the source_id from list order.
|
||||
context = []
|
||||
if source_video is not None:
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
context.append(vae.encode(vid[:, :, :, :3]))
|
||||
|
||||
if reference_images is not None:
|
||||
@ -80,13 +79,16 @@ class BerniniConditioning:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
|
||||
|
||||
return (positive, negative, {"samples": latent})
|
||||
return io.NodeOutput(positive, negative, {"samples": latent})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"BerniniConditioning": BerniniConditioning,
|
||||
}
|
||||
class BerniniExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
BerniniConditioning,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"BerniniConditioning": "Bernini Conditioning",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> BerniniExtension:
|
||||
return BerniniExtension()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user