diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 9016bb3be..2a8bbdeb2 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -8,7 +8,7 @@ import comfy.latent_formats
 import comfy.clip_vision
 import json
 import numpy as np
-from typing import Tuple
+from typing import Tuple, TypedDict
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 import logging
@@ -1291,12 +1291,25 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
 from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features
 
 class WanInfiniteTalkToVideo(io.ComfyNode):
+    class DCValues(TypedDict):
+        mode: str
+        audio_encoder_output_2: io.AudioEncoderOutput.Type
+        mask: io.Mask.Type
+
     @classmethod
     def define_schema(cls):
         return io.Schema(
             node_id="WanInfiniteTalkToVideo",
             category="conditioning/video_models",
             inputs=[
+                io.DynamicCombo.Input("mode", options=[
+                    io.DynamicCombo.Option("single_speaker", []),
+                    io.DynamicCombo.Option("two_speakers", [
+                        io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
+                        io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
+                        io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
+                    ]),
+                ]),
                 io.Model.Input("model"),
                 io.ModelPatch.Input("model_patch"),
                 io.Conditioning.Input("positive"),
@@ -1308,9 +1321,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                 io.Image.Input("start_image", optional=True),
                 io.AudioEncoderOutput.Input("audio_encoder_output_1"),
-                io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
-                io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
-                io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
                 io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
                 io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
                 io.Image.Input("previous_frames", optional=True),
@@ -1325,7 +1335,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
+    def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
                 start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
         if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
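
For reviewers unfamiliar with the `DynamicCombo` pattern introduced here: below is a minimal, self-contained sketch of how the `execute()` body might consume the `mode` value, under the assumption (suggested by the `DCValues` TypedDict above, but not shown in this diff) that the combo arrives as a dict carrying the selected option plus any inputs nested under it. The `unpack_mode` helper and the stand-in `object` types are illustrative only; the real delivery format is defined by the `io.DynamicCombo` implementation.

```python
from typing import Optional, Tuple, TypedDict

class DCValues(TypedDict, total=False):
    # Stand-ins for io.AudioEncoderOutput.Type / io.Mask.Type so this
    # sketch runs without ComfyUI installed.
    mode: str
    audio_encoder_output_2: Optional[object]
    mask_1: Optional[object]
    mask_2: Optional[object]

def unpack_mode(mode: DCValues) -> Tuple[Optional[object], Optional[object], Optional[object]]:
    """Pull the nested two-speaker inputs out of the combo value.

    Returns (audio_encoder_output_2, mask_1, mask_2); all None when the
    user picked "single_speaker", since that option exposes no extra inputs.
    """
    if mode.get("mode") == "two_speakers":
        return (mode.get("audio_encoder_output_2"),
                mode.get("mask_1"),
                mode.get("mask_2"))
    return None, None, None

# Example values as they might arrive from the two combo options.
print(unpack_mode({"mode": "single_speaker"}))
print(unpack_mode({"mode": "two_speakers", "mask_1": "MASK_A", "mask_2": "MASK_B"}))
```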