Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-08-29 07:36:57 +09:00
commit 2cc7bafb52
13 changed files with 188 additions and 62 deletions

View File

@@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
    return x

+
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE."""
@@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
    return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)

+
+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    if len(sigmas) <= 1:
+        return x
+
+    extra_args = {} if extra_args is None else extra_args
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+    return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    if len(sigmas) <= 1:
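Aside, not part of the commit: the two new entries are thin wrappers around the existing 2M SDE samplers, so given the same prepared inputs (model, x, sigmas and noise_sampler are placeholders here) these calls are equivalent:

    out_a = sample_dpmpp_2m_sde_heun(model, x, sigmas, noise_sampler=noise_sampler)
    out_b = sample_dpmpp_2m_sde(model, x, sigmas, noise_sampler=noise_sampler, solver_type='heun')
    # the heun variant only changes the default solver_type; every other argument is forwarded unchanged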

View File

@@ -158,7 +158,7 @@ class Flux(nn.Module):
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
-                        img += add
+                        img[:, :add.shape[1]] += add

        if img.dtype == torch.float16:
            img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)

@@ -189,7 +189,7 @@ class Flux(nn.Module):
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
-                        img[:, txt.shape[1] :, ...] += add
+                        img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add

        img = img[:, txt.shape[1] :, ...]
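Aside, not part of the commit: the sliced in-place add above applies the controlnet residual only to the tokens it covers, which matters once extra tokens are appended to img. A toy check with made-up sizes:

    import torch
    img = torch.zeros(1, 6, 4)     # e.g. 4 image tokens followed by 2 appended reference tokens (hypothetical)
    add = torch.ones(1, 4, 4)      # control residual covering only the first 4 tokens
    img[:, :add.shape[1]] += add   # trailing tokens stay untouched; a plain img += add would fail on the size mismatch
    print(img[0, :, 0])            # tensor([1., 1., 1., 1., 0., 0.])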

View File

@@ -459,7 +459,7 @@ class QwenImageTransformer2DModel(nn.Module):
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
-                        hidden_states += add
+                        hidden_states[:, :add.shape[1]] += add

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

View File

@@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
        audio_emb = None

        # embeddings
+        bs, _, time, height, width = x.shape
        x = self.patch_embedding(x.float()).to(x.dtype)
        if control_video is not None:
            x = x + self.cond_encoder(control_video)
@@ -1272,11 +1273,12 @@ class WanModel_S2V(WanModel):
        if reference_latent is not None:
            ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
            ref = ref.flatten(2).transpose(1, 2)
-            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
+            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
            ref = ref + cond_mask_weight[1]
            x = torch.cat([x, ref], dim=1)
            freqs = torch.cat([freqs, freqs_ref], dim=1)
            t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
+            del ref, freqs_ref

        if reference_motion is not None:
            motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
@@ -1286,6 +1288,7 @@ class WanModel_S2V(WanModel):
            t = torch.repeat_interleave(t, 2, dim=1)
            t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
+            del motion_encoded, freqs_motion

        # time embeddings
        e = self.time_embedding(
@@ -1296,7 +1299,6 @@ class WanModel_S2V(WanModel):
        # context
        context = self.text_embedding(context)

-
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
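Aside, not part of the commit: time here is the temporal latent size unpacked by the new bs, _, time, height, width line, so the rope start for the reference tokens is now input-dependent. As plain arithmetic:

    max(30, 10 + 9)   # -> 30: short inputs keep the old fixed offset
    max(30, 30 + 9)   # -> 39: longer inputs push the reference positions further out instead of clamping at 30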

View File

@@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor
        self.memory_usage_factor_conds = ()
+        self.memory_usage_shape_process = {}

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -350,7 +351,14 @@
        input_shapes = [input_shape]
        for c in self.memory_usage_factor_conds:
            shape = cond_shapes.get(c, None)
-            if shape is not None and len(shape) > 0:
-                input_shapes += shape
+            if shape is not None:
+                if c in self.memory_usage_shape_process:
+                    out = []
+                    for s in shape:
+                        out.append(self.memory_usage_shape_process[c](s))
+                    shape = out
+
+                if len(shape) > 0:
+                    input_shapes += shape

        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
@@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21):
class WAN22_S2V(WAN21):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
+        self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
+        self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
@@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21):
            out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
        return out

+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+
+        reference_motion = kwargs.get("reference_motion", None)
+        if reference_motion is not None:
+            out['reference_motion'] = reference_motion.shape
+        return out
+
class WAN22(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
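Aside, not part of the commit: memory_usage_shape_process lets a model rewrite a cond's shape before it is counted in the memory estimate. Feeding a made-up reference_motion shape through the lambda registered above:

    process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
    print(process["reference_motion"]([1, 16, 19, 60, 104]))   # [1, 16, 1.5, 60, 104]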

comfy/samplers.py (file mode changed from normal to executable)
View File

@@ -729,7 +729,7 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
+                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                  "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]

View File

@@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE):
    unet_extra_config = {}
    latent_format = latent_formats.Flux

-    memory_usage_factor = 2.8
+    memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.

    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

View File

@@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
            (mat1, mat2, alpha, None, None, None, None, None, None)
        )

+    def to_train(self):
+        return LokrDiff(self.weights)
+
    @classmethod
    def load(
        cls,

View File

@@ -105,6 +105,38 @@ class LatentInterpolate:
        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
        return (samples_out,)

+class LatentConcat:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "op"
+
+    CATEGORY = "latent/advanced"
+
+    def op(self, samples1, samples2, dim):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+        s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
+
+        if "-" in dim:
+            c = (s2, s1)
+        else:
+            c = (s1, s2)
+
+        if "x" in dim:
+            dim = -1
+        elif "y" in dim:
+            dim = -2
+        elif "t" in dim:
+            dim = -3
+
+        samples_out["samples"] = torch.cat(c, dim=dim)
+        return (samples_out,)
+
class LatentBatch:
    @classmethod
    def INPUT_TYPES(s):
@@ -279,6 +311,7 @@ NODE_CLASS_MAPPINGS = {
    "LatentSubtract": LatentSubtract,
    "LatentMultiply": LatentMultiply,
    "LatentInterpolate": LatentInterpolate,
+    "LatentConcat": LatentConcat,
    "LatentBatch": LatentBatch,
    "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
    "LatentApplyOperation": LatentApplyOperation,

View File

@@ -89,6 +89,7 @@ class DiffSynthCnetPatch:
        self.strength = strength
        self.mask = mask
        self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
+        self.encoded_image_size = (image.shape[1], image.shape[2])

    def encode_latent_cond(self, image):
        latent_image = self.vae.encode(image)
@@ -106,14 +107,15 @@
        x = kwargs.get("x")
        img = kwargs.get("img")
        block_index = kwargs.get("block_index")
-        if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
-            spacial_compression = self.vae.spacial_compression_encode()
+        spacial_compression = self.vae.spacial_compression_encode()
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
            comfy.model_management.load_models_gpu(loaded_models)

-        img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
+        img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
        kwargs['img'] = img
        return kwargs
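Aside, not part of the commit: the cache check now compares against the resolution implied by the current sampling call rather than the token shape. With made-up numbers:

    spacial_compression = 8                   # hypothetical VAE downscale factor
    x_h, x_w = 60, 104                        # latent spatial size of the current call
    target = (x_h * spacial_compression, x_w * spacial_compression)   # (480, 832) pixels
    needs_reencode = (480, 832) != target     # False: the cached control latent is reused as-is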

View File

@@ -877,6 +877,68 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_
    return batch_audio_eb, min_batch_num

+
+def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
+    latent_t = ((length - 1) // 4) + 1
+    if audio_encoder_output is not None:
+        feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
+        video_rate = 30
+        fps = 16
+        feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
+        batch_frames = latent_t * 4
+        audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
+        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
+        if len(audio_embed_bucket.shape) == 3:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
+        elif len(audio_embed_bucket.shape) == 4:
+            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
+
+        audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
+        if audio_embed_bucket.shape[3] > 0:
+            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
+            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
+        frame_offset += batch_frames
+
+    if ref_image is not None:
+        ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        ref_latent = vae.encode(ref_image[:, :, :, :3])
+        positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
+        negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
+
+    if ref_motion is not None:
+        if ref_motion.shape[0] > 73:
+            ref_motion = ref_motion[-73:]
+
+        ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+        if ref_motion.shape[0] < 73:
+            r = torch.ones([73, height, width, 3]) * 0.5
+            r[-ref_motion.shape[0]:] = ref_motion
+            ref_motion = r
+
+        ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
+
+    if ref_motion_latent is not None:
+        ref_motion_latent = ref_motion_latent[:, :, -19:]
+        positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
+        negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
+
+    latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+
+    control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
+    if control_video is not None:
+        control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        control_video = vae.encode(control_video[:, :, :, :3])
+        control_video_out[:, :, :control_video.shape[2]] = control_video
+
+    # TODO: check if zero is better than none if none provided
+    positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
+    negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
+
+    out_latent = {}
+    out_latent["samples"] = latent
+    return positive, negative, out_latent, frame_offset
+
+
class WanSoundImageToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -906,57 +968,44 @@
    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
-        latent_t = ((length - 1) // 4) + 1
-        if audio_encoder_output is not None:
-            feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
-            video_rate = 30
-            fps = 16
-            feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
-            audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate)
-            audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
-            if len(audio_embed_bucket.shape) == 3:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
-            elif len(audio_embed_bucket.shape) == 4:
-                audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
-
-            positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
-            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket})
-
-        if ref_image is not None:
-            ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            ref_latent = vae.encode(ref_image[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
-            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
-
-        if ref_motion is not None:
-            if ref_motion.shape[0] > 73:
-                ref_motion = ref_motion[-73:]
-
-            ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-
-            if ref_motion.shape[0] < 73:
-                r = torch.ones([73, height, width, 3]) * 0.5
-                r[-ref_motion.shape[0]:] = ref_motion
-                ref_motion = r
-
-            ref_motion = vae.encode(ref_motion[:, :, :, :3])
-            positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion})
-            negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion})
-
-        latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
-
-        control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
-        if control_video is not None:
-            control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            control_video = vae.encode(control_video[:, :, :, :3])
-            control_video_out[:, :, :control_video.shape[2]] = control_video
-
-        # TODO: check if zero is better than none if none provided
-        positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
-        negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
-
-        out_latent = {}
-        out_latent["samples"] = latent
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=ref_motion)
+        return io.NodeOutput(positive, negative, out_latent)
+
+
+class WanSoundImageToVideoExtend(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanSoundImageToVideoExtend",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Latent.Input("video_latent"),
+                io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
+                io.Image.Input("ref_image", optional=True),
+                io.Image.Input("control_video", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
+        video_latent = video_latent["samples"]
+        width = video_latent.shape[-1] * 8
+        height = video_latent.shape[-2] * 8
+        batch_size = video_latent.shape[0]
+        frame_offset = video_latent.shape[-3] * 4
+        positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
+                                                                          control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
        return io.NodeOutput(positive, negative, out_latent)
@@ -1019,6 +1068,7 @@ class WanExtension(ComfyExtension):
            WanCameraImageToVideo,
            WanPhantomSubjectToVideo,
            WanSoundImageToVideo,
+            WanSoundImageToVideoExtend,
            Wan22ImageToVideoLatent,
        ]
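Aside, not part of the commit: the new extend node derives everything it needs from the incoming video latent. With a hypothetical latent of shape [1, 16, 20, 60, 104] and the default length=77:

    width = 104 * 8                   # 832 pixels
    height = 60 * 8                   # 480 pixels
    batch_size = 1
    frame_offset = 20 * 4             # 80: the audio embedding window starts after the frames already generated
    latent_t = ((77 - 1) // 4) + 1    # 20 new latent frames, i.e. batch_frames = 80 more video frames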

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.52"
+__version__ = "0.3.54"

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.52"
+version = "0.3.54"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"