Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-08-18 07:35:15 +09:00
commit 8b44e58e6c
14 changed files with 573 additions and 347 deletions

View File

@@ -224,19 +224,27 @@ class Flux(nn.Module):
         if ref_latents is not None:
             h = 0
             w = 0
+            index = 0
+            index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
             for ref in ref_latents:
-                h_offset = 0
-                w_offset = 0
-                if ref.shape[-2] + h > ref.shape[-1] + w:
-                    w_offset = w
-                else:
-                    h_offset = h
+                if index_ref_method:
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
 
-                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 img = torch.cat([img, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
-                h = max(h, ref.shape[-2] + h_offset)
-                w = max(w, ref.shape[-1] + w_offset)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
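For readers following the Flux change above: the new `ref_latents_method` switch keeps the old "offset" behaviour (each reference is placed beside the larger edge of a growing virtual canvas) and adds an "index" mode that leaves offsets at zero and instead gives each reference an increasing index. A minimal sketch of just that placement decision, written with plain height/width tuples instead of latent tensors; the helper name and return shape are illustrative, not part of the codebase.

# Illustrative sketch (not ComfyUI API): reproduces the placement decision the
# Flux hunk makes for each reference latent.
def place_reference_latents(shapes, method="offset"):
    h = w = 0
    index = 0
    placements = []
    for (ref_h, ref_w) in shapes:
        if method == "index":
            index += 1
            h_offset = w_offset = 0
        else:
            index = 1
            h_offset = w_offset = 0
            if ref_h + h > ref_w + w:
                w_offset = w      # canvas is taller than wide: place to the right
            else:
                h_offset = h      # canvas is wider than tall: place below
            h = max(h, ref_h + h_offset)
            w = max(w, ref_w + w_offset)
        placements.append((index, h_offset, w_offset))
    return placements

# Example with two 64x64 references:
# offset mode -> [(1, 0, 0), (1, 64, 0)]   (second reference placed below)
# index mode  -> [(1, 0, 0), (2, 0, 0)]    (same position, different index)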

View File

@@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
 
-    def pos_embeds(self, x, context):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
-        txt_start = round(max(h_len, w_len))
-        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
 
     def forward(
         self,
@@ -356,19 +360,46 @@ class QwenImageTransformer2DModel(nn.Module):
         context,
         attention_mask=None,
         guidance: torch.Tensor = None,
+        ref_latents=None,
+        transformer_options={},
         **kwargs
     ):
         timestep = timesteps
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask
 
-        image_rotary_emb = self.pos_embeds(x, context)
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
 
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            index = 0
+            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            for ref in ref_latents:
+                if index_ref_method:
+                    index += 1
+                    h_offset = 0
+                    w_offset = 0
+                else:
+                    index = 1
+                    h_offset = 0
+                    w_offset = 0
+                    if ref.shape[-2] + h > ref.shape[-1] + w:
+                        w_offset = w
+                    else:
+                        h_offset = h
+                    h = max(h, ref.shape[-2] + h_offset)
+                    w = max(w, ref.shape[-1] + w_offset)
+
+                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                hidden_states = torch.cat([hidden_states, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
 
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -383,18 +414,30 @@ class QwenImageTransformer2DModel(nn.Module):
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        for block in self.transformer_blocks:
-            encoder_hidden_states, hidden_states = block(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                temb=temb,
-                image_rotary_emb=image_rotary_emb,
-            )
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+                hidden_states = out["img"]
+                encoder_hidden_states = out["txt"]
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
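The QwenImage forward now honours `transformer_options["patches_replace"]["dit"]`, the same hook Flux-style models expose: an entry keyed by `("double_block", i)` is called with the block tensors as a dict plus a wrapper that runs the original block. A rough sketch of what such a patch callable can look like; the post-processing inside it and the way the dict reaches the model are illustrative assumptions, not part of this diff.

# Illustrative sketch of a "double_block" replacement patch as consumed by the
# loop above. It receives {"img", "txt", "vec", "pe"} and a dict holding the
# original block wrapper, and must return {"img": ..., "txt": ...}.
def my_double_block_patch(args, extra):
    out = extra["original_block"](args)   # run the unpatched block first
    out["img"] = out["img"] * 1.0         # illustrative no-op post-processing hook
    return out

# Hypothetical registration; the plumbing that fills transformer_options
# (e.g. a ModelPatcher-style wrapper) is not shown here.
transformer_options = {
    "patches_replace": {
        "dit": {("double_block", 3): my_double_block_patch},
    }
}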

View File

@@ -768,7 +768,12 @@ class CameraWanModel(WanModel):
                  operations=None,
                  ):
 
-        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        if model_type == 'camera':
+            model_type = 'i2v'
+        else:
+            model_type = 't2v'
+
+        super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}
 
         self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)

View File

@@ -890,6 +890,10 @@ class Flux(BaseModel):
             for lat in ref_latents:
                 latents.append(self.process_latent_in(lat))
             out['ref_latents'] = comfy.conds.CONDList(latents)
+
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
         return out
 
     def extra_conds_shapes(self, **kwargs):
@@ -1327,4 +1331,14 @@ class QwenImage(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
+
+        ref_latents_method = kwargs.get("reference_latents_method", None)
+        if ref_latents_method is not None:
+            out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
         return out

View File

@@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
             dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
         elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
-            dit_config["model_type"] = "camera"
+            if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+                dit_config["model_type"] = "camera"
+            else:
+                dit_config["model_type"] = "camera_2.2"
         else:
             if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
                 dit_config["model_type"] = "i2v"

View File

@@ -32,18 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
+        import inspect
+        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+            SDPA_BACKEND_PRIORITY = [
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+                SDPBackend.MATH,
+            ]
 
-        SDPA_BACKEND_PRIORITY = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH,
-        ]
+            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
-        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        else:
+            logging.warning("Torch version too old to set sdpa backend priority.")
 
-        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
-        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")
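The ops change gates the new `set_priority` keyword behind a signature check rather than a version compare, so older PyTorch builds fall back with a warning instead of raising TypeError at call time. The same feature-detection pattern in isolation; this is a generic sketch and not tied to torch.

import inspect
import logging

def supports_keyword(func, name):
    """Return True if callable `func` accepts a keyword argument `name`."""
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins and C extensions do not expose a signature.
        return False

def greet(text, *, excited=False):
    return text + ("!" if excited else "")

if supports_keyword(greet, "excited"):
    print(greet("hello", excited=True))
else:
    logging.warning("excited keyword not supported, using plain call")
    print(greet("hello"))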

View File

@@ -1,6 +1,7 @@
 import torch
 import comfy.model_management
 import numbers
+import logging
 
 RMSNorm = None
 
@@ -9,6 +10,7 @@ try:
     RMSNorm = torch.nn.RMSNorm
 except:
     rms_norm_torch = None
+    logging.warning("Please update pytorch to use native RMSNorm")
 
 
 def rms_norm(x, weight=None, eps=1e-6):

View File

@@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
         return out
+
+class WAN22_Camera(WAN21_T2V):
+    unet_config = {
+        "image_model": "wan2.1",
+        "model_type": "camera_2.2",
+        "in_dim": 36,
+    }
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
+        return out
+
 class WAN21_Vace(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
@@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
 
 models += [SVD_img2vid]

View File

@@ -1,6 +1,5 @@
 import logging
 from typing import Any, Callable, Optional, TypeVar
-import random
 
 import torch
 from comfy_api_nodes.util.validation_utils import (
     get_image_dimensions,
@@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
 def _validate_video_dimensions(width: int, height: int) -> None:
     """Validates video dimensions meet Moonvalley V2V requirements."""
     supported_resolutions = {
-        (1920, 1080), (1080, 1920), (1152, 1152),
-        (1536, 1152), (1152, 1536)
+        (1920, 1080),
+        (1080, 1920),
+        (1152, 1152),
+        (1536, 1152),
+        (1152, 1536),
     }
 
     if (width, height) not in supported_resolutions:
-        supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
-        raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
+        supported_list = ", ".join(
+            [f"{w}x{h}" for w, h in sorted(supported_resolutions)]
+        )
+        raise ValueError(
+            f"Resolution {width}x{height} not supported. Supported: {supported_list}"
+        )
 
 
 def _validate_container_format(video: VideoInput) -> None:
     """Validates video container format is MP4."""
     container_format = video.get_container_format()
-    if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']:
-        raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
+    if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
+        raise ValueError(
+            f"Only MP4 container format supported. Got: {container_format}"
+        )
 
 
 def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
     return video
 
-
 def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
     """
     Returns a new VideoInput object trimmed from the beginning to the specified duration,
@@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
         # Calculate target frame count that's divisible by 16
         fps = input_container.streams.video[0].average_rate
         estimated_frames = int(duration_sec * fps)
-        target_frames = (estimated_frames // 16) * 16  # Round down to nearest multiple of 16
+        target_frames = (
+            estimated_frames // 16
+        ) * 16  # Round down to nearest multiple of 16
 
         if target_frames == 0:
             raise ValueError("Video too short: need at least 16 frames for Moonvalley")
@@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
                     MoonvalleyTextToVideoInferenceParams,
                     "negative_prompt",
                     multiline=True,
-                    default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
+                    default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
                 ),
                 "resolution": (
                     IO.COMBO,
@@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
                         "tooltip": "Resolution of the output video",
                     },
                 ),
-                # "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
                 "prompt_adherence": model_field_to_node_input(
                     IO.FLOAT,
                     MoonvalleyTextToVideoInferenceParams,
                     "guidance_scale",
-                    default=7.0,
+                    default=10.0,
                     step=1,
                     min=1,
                     max=20,
@@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
                     IO.INT,
                     MoonvalleyTextToVideoInferenceParams,
                     "seed",
-                    default=random.randint(0, 2**32 - 1),
+                    default=9,
                     min=0,
                     max=4294967295,
                     step=1,
                     display="number",
                     tooltip="Random seed value",
-                    control_after_generate=True,
                 ),
                 "steps": model_field_to_node_input(
                     IO.INT,
@@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
         # Get MIME type from tensor - assuming PNG format for image tensors
         mime_type = "image/png"
 
-        image_url = (await upload_images_to_comfyapi(
-            image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
-        ))[0]
+        image_url = (
+            await upload_images_to_comfyapi(
+                image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
+            )
+        )[0]
 
         request = MoonvalleyTextToVideoRequest(
             image_url=image_url, prompt_text=prompt, inference_params=inference_params
@@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
         return {
             "required": {
                 "prompt": model_field_to_node_input(
-                    IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
-                    multiline=True
+                    IO.STRING,
+                    MoonvalleyVideoToVideoRequest,
+                    "prompt_text",
+                    multiline=True,
                 ),
                 "negative_prompt": model_field_to_node_input(
                     IO.STRING,
                     MoonvalleyVideoToVideoInferenceParams,
                     "negative_prompt",
                     multiline=True,
-                    default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
+                    default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
+                ),
+                "seed": model_field_to_node_input(
+                    IO.INT,
+                    MoonvalleyVideoToVideoInferenceParams,
+                    "seed",
+                    default=9,
+                    min=0,
+                    max=4294967295,
+                    step=1,
+                    display="number",
+                    tooltip="Random seed value",
+                    control_after_generate=False,
+                ),
+                "prompt_adherence": model_field_to_node_input(
+                    IO.FLOAT,
+                    MoonvalleyVideoToVideoInferenceParams,
+                    "guidance_scale",
+                    default=10.0,
+                    step=1,
+                    min=1,
+                    max=20,
                 ),
-                "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
             },
             "hidden": {
                 "auth_token": "AUTH_TOKEN_COMFY_ORG",
@@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
                 "unique_id": "UNIQUE_ID",
             },
             "optional": {
-                "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
+                "video": (
+                    IO.VIDEO,
+                    {
+                        "default": "",
+                        "multiline": False,
+                        "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
+                    },
+                ),
                 "control_type": (
                     ["Motion Transfer", "Pose Transfer"],
                     {"default": "Motion Transfer"},
@@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
                         "max": 100,
                         "tooltip": "Only used if control_type is 'Motion Transfer'",
                     },
-                )
-            }
+                ),
+                "image": model_field_to_node_input(
+                    IO.IMAGE,
+                    MoonvalleyTextToVideoRequest,
+                    "image_url",
+                    tooltip="The reference image used to generate the video",
+                ),
+            },
         }
 
     RETURN_TYPES = ("VIDEO",)
@@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
         self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
     ):
         video = kwargs.get("video")
+        image = kwargs.get("image", None)
 
         if not video:
             raise MoonvalleyApiError("video is required")
@@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
         video_url = ""
         if video:
             validated_video = validate_video_to_video_input(video)
-            video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
+            video_url = await upload_video_to_comfyapi(
+                validated_video, auth_kwargs=kwargs
+            )
+
+        mime_type = "image/png"
+        if not image is None:
+            validate_input_image(image, with_frame_conditioning=True)
+            image_url = await upload_images_to_comfyapi(
+                image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
+            )
 
         control_type = kwargs.get("control_type")
         motion_intensity = kwargs.get("motion_intensity")
@@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
         # Only include motion_intensity for Motion Transfer
         control_params = {}
         if control_type == "Motion Transfer" and motion_intensity is not None:
-            control_params['motion_intensity'] = motion_intensity
+            control_params["motion_intensity"] = motion_intensity
 
-        inference_params=MoonvalleyVideoToVideoInferenceParams(
+        inference_params = MoonvalleyVideoToVideoInferenceParams(
             negative_prompt=negative_prompt,
             seed=kwargs.get("seed"),
-            control_params=control_params
+            control_params=control_params,
         )
 
         control = self.parseControlParameter(control_type)
@@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
             prompt_text=prompt,
             inference_params=inference_params,
         )
+        request.image_url = image_url if not image is None else None
 
         initial_operation = SynchronousOperation(
             endpoint=ApiEndpoint(
@@ -694,15 +748,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
         validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
 
-        inference_params=MoonvalleyTextToVideoInferenceParams(
+        inference_params = MoonvalleyTextToVideoInferenceParams(
             negative_prompt=negative_prompt,
             steps=kwargs.get("steps"),
             seed=kwargs.get("seed"),
             guidance_scale=kwargs.get("prompt_adherence"),
             num_frames=128,
             width=width_height.get("width"),
             height=width_height.get("height"),
         )
         request = MoonvalleyTextToVideoRequest(
             prompt_text=prompt, inference_params=inference_params
         )

View File

@@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
             path = "/proxy/openai/images/generations"
             content_type = "application/json"
             request_class = OpenAIImageGenerationRequest
-        img_binaries = []
-        mask_binary = None
         files = []
 
         if image is not None:
@@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
                 img_byte_arr = io.BytesIO()
                 img.save(img_byte_arr, format="PNG")
                 img_byte_arr.seek(0)
-                img_binary = img_byte_arr
-                img_binary.name = f"image_{i}.png"
-                img_binaries.append(img_binary)
                 if batch_size == 1:
-                    files.append(("image", img_binary))
+                    files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
                 else:
-                    files.append(("image[]", img_binary))
+                    files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
 
         if mask is not None:
             if image is None:
@@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
             mask_img_byte_arr = io.BytesIO()
             mask_img.save(mask_img_byte_arr, format="PNG")
             mask_img_byte_arr.seek(0)
-            mask_binary = mask_img_byte_arr
-            mask_binary.name = "mask.png"
-            files.append(("mask", mask_binary))
+            files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
 
         # Build the operation
         operation = SynchronousOperation(
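The OpenAI node change drops the trick of assigning a `.name` attribute to `BytesIO` objects and instead passes multipart entries as `(field_name, (filename, file_object, mime_type))` tuples. A minimal standalone sketch of that tuple form using the `requests` library, which accepts the same structure; the endpoint URL is hypothetical and the node itself hands its list to its own API operation rather than calling requests directly.

import io
import requests

buf = io.BytesIO(b"...png bytes...")
files = [
    # (field_name, (filename, file_object, mime_type)); no .name attribute needed
    ("image", ("image_0.png", buf, "image/png")),
]
resp = requests.post("https://example.invalid/upload", files=files)  # hypothetical endpoint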

View File

@@ -346,6 +346,24 @@ class LoadAudio:
             return "Invalid audio file: {}".format(audio)
         return True
 
+
+class RecordAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"audio": ("AUDIO_RECORD", {})}}
+
+    CATEGORY = "audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "load"
+
+    def load(self, audio):
+        audio_path = folder_paths.get_annotated_filepath(audio)
+
+        waveform, sample_rate = torchaudio.load(audio_path)
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
+        return (audio, )
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLatentAudio": EmptyLatentAudio,
     "VAEEncodeAudio": VAEEncodeAudio,
@@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
     "LoadAudio": LoadAudio,
     "PreviewAudio": PreviewAudio,
     "ConditioningStableAudio": ConditioningStableAudio,
+    "RecordAudio": RecordAudio,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SaveAudio": "Save Audio (FLAC)",
     "SaveAudioMP3": "Save Audio (MP3)",
     "SaveAudioOpus": "Save Audio (Opus)",
+    "RecordAudio": "Record Audio",
 }
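The new RecordAudio node returns audio in the same dictionary shape the other ComfyUI audio nodes use: a `waveform` tensor with a leading batch dimension plus a `sample_rate`. A small sketch of building that structure outside the node, assuming torchaudio is installed and a local file named clip.wav exists.

import torchaudio

waveform, sample_rate = torchaudio.load("clip.wav")   # waveform: [channels, samples]
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}  # add batch dim
print(audio["waveform"].shape, audio["sample_rate"])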

View File

@@ -100,9 +100,28 @@ class FluxKontextImageScale:
         return (image, )
 
 
+class FluxKontextMultiReferenceLatentMethod:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+                    "conditioning": ("CONDITIONING", ),
+                    "reference_latents_method": (("offset", "index"), ),
+                    }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+    EXPERIMENTAL = True
+
+    CATEGORY = "advanced/conditioning/flux"
+
+    def append(self, conditioning, reference_latents_method):
+        c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
+        return (c, )
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
     "FluxGuidance": FluxGuidance,
     "FluxDisableGuidance": FluxDisableGuidance,
     "FluxKontextImageScale": FluxKontextImageScale,
+    "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
 }
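FluxKontextMultiReferenceLatentMethod only stamps a `reference_latents_method` key onto every conditioning entry; downstream, `model_base.Flux.extra_conds` (and the QwenImage equivalent) picks it up as `ref_latents_method`. A rough sketch of what this kind of stamping amounts to, given that ComfyUI conditioning is a list of `[tensor, options_dict]` pairs; this is an illustration of the idea, not the library's implementation.

# Illustrative only: copy each conditioning entry and merge extra keys into its
# options dict. node_helpers.conditioning_set_values plays this role in the node above.
def set_cond_values(conditioning, values):
    out = []
    for cond_tensor, options in conditioning:
        new_options = dict(options)
        new_options.update(values)
        out.append([cond_tensor, new_options])
    return out

# e.g. set_cond_values(conditioning, {"reference_latents_method": "index"})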

View File

@@ -9,29 +9,35 @@ import comfy.clip_vision
 import json
 import numpy as np
 from typing import Tuple
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
 
-class WanImageToVideo:
+class WanImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -51,32 +57,36 @@ class WanImageToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFunControlToVideo:
+class WanFunControlToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "control_video": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFunControlToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -101,31 +111,34 @@ class WanFunControlToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class Wan22FunControlToVideo:
+class Wan22FunControlToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"ref_image": ("IMAGE", ),
-                             "control_video": ("IMAGE", ),
-                             # "start_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Wan22FunControlToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("ref_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -158,32 +171,36 @@ class Wan22FunControlToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFirstLastFrameToVideo:
+class WanFirstLastFrameToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
-                             "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "end_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFirstLastFrameToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
+                io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("end_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -224,62 +241,70 @@ class WanFirstLastFrameToVideo:
 
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent)
 
-class WanFunInpaintToVideo:
+class WanFunInpaintToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "end_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanFunInpaintToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.Image.Input("end_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
         flfv = WanFirstLastFrameToVideo()
-        return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
+        return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
 
-class WanVaceToVideo:
+class WanVaceToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
-                },
-                "optional": {"control_video": ("IMAGE", ),
-                             "control_masks": ("MASK", ),
-                             "reference_image": ("IMAGE", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanVaceToVideo",
+            category="conditioning/video_models",
+            is_experimental=True,
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
+                io.Image.Input("control_video", optional=True),
+                io.Mask.Input("control_masks", optional=True),
+                io.Image.Input("reference_image", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+                io.Int.Output(display_name="trim_latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
-    RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-    EXPERIMENTAL = True
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
         latent_length = ((length - 1) // 4) + 1
         if control_video is not None:
             control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
@@ -336,52 +361,59 @@ class WanVaceToVideo:
         latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         out_latent = {}
         out_latent["samples"] = latent
-        return (positive, negative, out_latent, trim_latent)
+        return io.NodeOutput(positive, negative, out_latent, trim_latent)
 
-class TrimVideoLatent:
+class TrimVideoLatent(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "samples": ("LATENT",),
-                              "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
-                              }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TrimVideoLatent",
+            category="latent/video",
+            is_experimental=True,
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Int.Input("trim_amount", default=0, min=0, max=99999),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
 
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "op"
-
-    CATEGORY = "latent/video"
-
-    EXPERIMENTAL = True
-
-    def op(self, samples, trim_amount):
+    @classmethod
+    def execute(cls, samples, trim_amount) -> io.NodeOutput:
         samples_out = samples.copy()
 
         s1 = samples["samples"]
         samples_out["samples"] = s1[:, :, trim_amount:]
-        return (samples_out,)
+        return io.NodeOutput(samples_out)
 
-class WanCameraImageToVideo:
+class WanCameraImageToVideo(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"positive": ("CONDITIONING", ),
-                             "negative": ("CONDITIONING", ),
-                             "vae": ("VAE", ),
-                             "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                             "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
-                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                },
-                "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
-                             "start_image": ("IMAGE", ),
-                             "camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
-                }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanCameraImageToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.WanCameraEmbedding.Input("camera_conditions", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
-    RETURN_NAMES = ("positive", "negative", "latent")
-    FUNCTION = "encode"
-
-    CATEGORY = "conditioning/video_models"
-
-    def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
+    @classmethod
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
@@ -390,9 +422,12 @@ class WanCameraImageToVideo:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
             concat_latent_image = vae.encode(start_image[:, :, :, :3])
             concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
+            mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
+            mask[:, :, :start_image.shape[0] + 3] = 0.0
+            mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
 
-        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
-        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
 
         if camera_conditions is not None:
             positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
@ -404,29 +439,34 @@ class WanCameraImageToVideo:
out_latent = {} out_latent = {}
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, negative, out_latent) return io.NodeOutput(positive, negative, out_latent)
class WanPhantomSubjectToVideo: class WanPhantomSubjectToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="WanPhantomSubjectToVideo",
"vae": ("VAE", ), category="conditioning/video_models",
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), inputs=[
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), io.Conditioning.Input("positive"),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), io.Conditioning.Input("negative"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Vae.Input("vae"),
}, io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
"optional": {"images": ("IMAGE", ), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
}} io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("images", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative_text"),
io.Conditioning.Output(display_name="negative_img_text"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative cond2 = negative
if images is not None: if images is not None:
@ -442,7 +482,7 @@ class WanPhantomSubjectToVideo:
out_latent = {} out_latent = {}
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, cond2, negative, out_latent) return io.NodeOutput(positive, cond2, negative, out_latent)
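All of the Wan 2.1 nodes in this file allocate their empty latent the same way: 16 channels, one latent frame per 4 pixel frames (plus the first), and an 8x spatial downscale. A minimal sketch of that sizing using the defaults from the WanPhantomSubjectToVideo schema above (device handling via comfy.model_management is omitted here):

import torch

width, height, length, batch_size = 832, 480, 81, 1

latent = torch.zeros([
    batch_size,
    16,                        # Wan 2.1 latent channels
    ((length - 1) // 4) + 1,   # 81 pixel frames -> 21 latent frames
    height // 8,               # 480 -> 60
    width // 8,                # 832 -> 104
])
print(latent.shape)            # torch.Size([1, 16, 21, 60, 104])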
def parse_json_tracks(tracks): def parse_json_tracks(tracks):
"""Parse JSON track data into a standardized format""" """Parse JSON track data into a standardized format"""
@ -655,39 +695,41 @@ def patch_motion(
return out_mask_full, out_feature_full return out_mask_full, out_feature_full
class WanTrackToVideo: class WanTrackToVideo(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"positive": ("CONDITIONING", ), node_id="WanPhantomSubjectToVideo",
"negative": ("CONDITIONING", ), category="conditioning/video_models",
"vae": ("VAE", ), inputs=[
"tracks": ("STRING", {"multiline": True, "default": "[]"}), io.Conditioning.Input("positive"),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), io.Conditioning.Input("negative"),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), io.Vae.Input("vae"),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), io.String.Input("tracks", multiline=True, default="[]"),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
"topk": ("INT", {"default": 2, "min": 1, "max": 10}), io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
"start_image": ("IMAGE", ), io.Int.Input("batch_size", default=1, min=1, max=4096),
}, io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
"optional": { io.Int.Input("topk", default=2, min=1, max=10),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ), io.Image.Input("start_image"),
}} io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") @classmethod
RETURN_NAMES = ("positive", "negative", "latent") def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
FUNCTION = "encode" temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
temperature, topk, start_image=None, clip_vision_output=None):
tracks_data = parse_json_tracks(tracks) tracks_data = parse_json_tracks(tracks)
if not tracks_data: if not tracks_data:
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device()) device=comfy.model_management.intermediate_device())
@ -741,34 +783,36 @@ class WanTrackToVideo:
out_latent = {} out_latent = {}
out_latent["samples"] = latent out_latent["samples"] = latent
return (positive, negative, out_latent) return io.NodeOutput(positive, negative, out_latent)
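The conversions above all follow the same V1-to-V3 shape: INPUT_TYPES, RETURN_TYPES, RETURN_NAMES, FUNCTION, and CATEGORY collapse into a single define_schema classmethod returning io.Schema, and the instance method encode becomes a classmethod execute returning io.NodeOutput. A minimal sketch of that pattern with a hypothetical node; the node name is invented for illustration and the import path is an assumption (it should match whatever io import this file already uses):

import torch
from comfy_api.latest import io   # assumed import path for the V3 io module

class ExampleEmptyWanLatent(io.ComfyNode):   # hypothetical node, for illustration only
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ExampleEmptyWanLatent",
            category="conditioning/video_models",
            inputs=[
                io.Int.Input("length", default=81, min=1, max=1024, step=4),
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, length) -> io.NodeOutput:
        # Same sizing convention as the Wan 2.1 nodes in this file.
        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, 60, 104])
        return io.NodeOutput({"samples": latent})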
class Wan22ImageToVideoLatent: class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"vae": ("VAE", ), return io.Schema(
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), node_id="Wan22ImageToVideoLatent",
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), category="conditioning/inpaint",
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), inputs=[
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), io.Vae.Input("vae"),
}, io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
"optional": {"start_image": ("IMAGE", ), io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
}} io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
RETURN_TYPES = ("LATENT",) def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None):
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
if start_image is None: if start_image is None:
out_latent = {} out_latent = {}
out_latent["samples"] = latent out_latent["samples"] = latent
return (out_latent,) return io.NodeOutput(out_latent)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
@ -783,19 +827,25 @@ class Wan22ImageToVideoLatent:
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,) return io.NodeOutput(out_latent)
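Wan 2.2 uses a 48-channel latent with a 16x spatial downscale, and the batch dimension is only expanded at the very end via repeat, so the single encoded sample and its noise mask are shared across the batch. A minimal sketch of that repeat pattern, using the node's default width, height, and length:

import torch

batch_size = 4
latent = torch.zeros([1, 48, 13, 44, 80])    # length=49 -> 13 frames, 704/16=44, 1280/16=80
mask = torch.ones([1, 1, 13, 44, 80])

# Expand only the batch dimension, keeping every other dimension unchanged.
samples = latent.repeat((batch_size,) + (1,) * (latent.ndim - 1))
noise_mask = mask.repeat((batch_size,) + (1,) * (mask.ndim - 1))
print(samples.shape)      # torch.Size([4, 48, 13, 44, 80])
print(noise_mask.shape)   # torch.Size([4, 1, 13, 44, 80])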
NODE_CLASS_MAPPINGS = { class WanExtension(ComfyExtension):
"WanTrackToVideo": WanTrackToVideo, @override
"WanImageToVideo": WanImageToVideo, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"WanFunControlToVideo": WanFunControlToVideo, return [
"Wan22FunControlToVideo": Wan22FunControlToVideo, WanTrackToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo, WanImageToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, WanFunControlToVideo,
"WanVaceToVideo": WanVaceToVideo, Wan22FunControlToVideo,
"TrimVideoLatent": TrimVideoLatent, WanFunInpaintToVideo,
"WanCameraImageToVideo": WanCameraImageToVideo, WanFirstLastFrameToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, WanVaceToVideo,
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent, TrimVideoLatent,
} WanCameraImageToVideo,
WanPhantomSubjectToVideo,
Wan22ImageToVideoLatent,
]
async def comfy_entrypoint() -> WanExtension:
return WanExtension()

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.24.4 comfyui-frontend-package==1.25.8
comfyui-workflow-templates==0.1.59 comfyui-workflow-templates==0.1.60
comfyui-embedded-docs==0.2.6 comfyui-embedded-docs==0.2.6
comfyui_manager comfyui_manager
torch torch