mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 02:23:06 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
2cc7bafb52
@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||||
|
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||||
"""DPM-Solver++(3M) SDE."""
|
"""DPM-Solver++(3M) SDE."""
|
||||||
@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
|||||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
|
||||||
|
if len(sigmas) <= 1:
|
||||||
|
return x
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
|
return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
if len(sigmas) <= 1:
|
if len(sigmas) <= 1:
|
||||||
|
|||||||
@ -158,7 +158,7 @@ class Flux(nn.Module):
|
|||||||
if i < len(control_i):
|
if i < len(control_i):
|
||||||
add = control_i[i]
|
add = control_i[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img += add
|
img[:, :add.shape[1]] += add
|
||||||
|
|
||||||
if img.dtype == torch.float16:
|
if img.dtype == torch.float16:
|
||||||
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
|
||||||
@ -189,7 +189,7 @@ class Flux(nn.Module):
|
|||||||
if i < len(control_o):
|
if i < len(control_o):
|
||||||
add = control_o[i]
|
add = control_o[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img[:, txt.shape[1] :, ...] += add
|
img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
|
|||||||
@ -459,7 +459,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
if i < len(control_i):
|
if i < len(control_i):
|
||||||
add = control_i[i]
|
add = control_i[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
hidden_states += add
|
hidden_states[:, :add.shape[1]] += add
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|||||||
@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
|
|||||||
audio_emb = None
|
audio_emb = None
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
|
bs, _, time, height, width = x.shape
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
if control_video is not None:
|
if control_video is not None:
|
||||||
x = x + self.cond_encoder(control_video)
|
x = x + self.cond_encoder(control_video)
|
||||||
@ -1272,11 +1273,12 @@ class WanModel_S2V(WanModel):
|
|||||||
if reference_latent is not None:
|
if reference_latent is not None:
|
||||||
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
|
||||||
ref = ref.flatten(2).transpose(1, 2)
|
ref = ref.flatten(2).transpose(1, 2)
|
||||||
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
|
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
|
||||||
ref = ref + cond_mask_weight[1]
|
ref = ref + cond_mask_weight[1]
|
||||||
x = torch.cat([x, ref], dim=1)
|
x = torch.cat([x, ref], dim=1)
|
||||||
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
freqs = torch.cat([freqs, freqs_ref], dim=1)
|
||||||
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del ref, freqs_ref
|
||||||
|
|
||||||
if reference_motion is not None:
|
if reference_motion is not None:
|
||||||
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
|
||||||
@ -1286,6 +1288,7 @@ class WanModel_S2V(WanModel):
|
|||||||
|
|
||||||
t = torch.repeat_interleave(t, 2, dim=1)
|
t = torch.repeat_interleave(t, 2, dim=1)
|
||||||
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
|
||||||
|
del motion_encoded, freqs_motion
|
||||||
|
|
||||||
# time embeddings
|
# time embeddings
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
@ -1296,7 +1299,6 @@ class WanModel_S2V(WanModel):
|
|||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
|||||||
@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
logging.debug("adm {}".format(self.adm_channels))
|
logging.debug("adm {}".format(self.adm_channels))
|
||||||
self.memory_usage_factor = model_config.memory_usage_factor
|
self.memory_usage_factor = model_config.memory_usage_factor
|
||||||
self.memory_usage_factor_conds = ()
|
self.memory_usage_factor_conds = ()
|
||||||
|
self.memory_usage_shape_process = {}
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes = [input_shape]
|
input_shapes = [input_shape]
|
||||||
for c in self.memory_usage_factor_conds:
|
for c in self.memory_usage_factor_conds:
|
||||||
shape = cond_shapes.get(c, None)
|
shape = cond_shapes.get(c, None)
|
||||||
if shape is not None and len(shape) > 0:
|
if shape is not None:
|
||||||
input_shapes += shape
|
if c in self.memory_usage_shape_process:
|
||||||
|
out = []
|
||||||
|
for s in shape:
|
||||||
|
out.append(self.memory_usage_shape_process[c](s))
|
||||||
|
shape = out
|
||||||
|
|
||||||
|
if len(shape) > 0:
|
||||||
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21):
|
|||||||
class WAN22_S2V(WAN21):
|
class WAN22_S2V(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||||
|
self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
|
||||||
|
self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21):
|
|||||||
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
|
|
||||||
|
reference_motion = kwargs.get("reference_motion", None)
|
||||||
|
if reference_motion is not None:
|
||||||
|
out['reference_motion'] = reference_motion.shape
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN22(BaseModel):
|
class WAN22(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
|||||||
2
comfy/samplers.py
Normal file → Executable file
2
comfy/samplers.py
Normal file → Executable file
@ -729,7 +729,7 @@ class Sampler:
|
|||||||
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||||
|
|
||||||
|
|||||||
@ -700,7 +700,7 @@ class Flux(supported_models_base.BASE):
|
|||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.Flux
|
latent_format = latent_formats.Flux
|
||||||
|
|
||||||
memory_usage_factor = 2.8
|
memory_usage_factor = 3.1 # TODO: debug why flux mem usage is so weird on windows.
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
|||||||
@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_train(self):
|
||||||
|
return LokrDiff(self.weights)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -105,6 +105,38 @@ class LatentInterpolate:
|
|||||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||||
return (samples_out,)
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentConcat:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "op"
|
||||||
|
|
||||||
|
CATEGORY = "latent/advanced"
|
||||||
|
|
||||||
|
def op(self, samples1, samples2, dim):
|
||||||
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
|
s1 = samples1["samples"]
|
||||||
|
s2 = samples2["samples"]
|
||||||
|
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
|
||||||
|
|
||||||
|
if "-" in dim:
|
||||||
|
c = (s2, s1)
|
||||||
|
else:
|
||||||
|
c = (s1, s2)
|
||||||
|
|
||||||
|
if "x" in dim:
|
||||||
|
dim = -1
|
||||||
|
elif "y" in dim:
|
||||||
|
dim = -2
|
||||||
|
elif "t" in dim:
|
||||||
|
dim = -3
|
||||||
|
|
||||||
|
samples_out["samples"] = torch.cat(c, dim=dim)
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
class LatentBatch:
|
class LatentBatch:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -279,6 +311,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentSubtract": LatentSubtract,
|
"LatentSubtract": LatentSubtract,
|
||||||
"LatentMultiply": LatentMultiply,
|
"LatentMultiply": LatentMultiply,
|
||||||
"LatentInterpolate": LatentInterpolate,
|
"LatentInterpolate": LatentInterpolate,
|
||||||
|
"LatentConcat": LatentConcat,
|
||||||
"LatentBatch": LatentBatch,
|
"LatentBatch": LatentBatch,
|
||||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
||||||
"LatentApplyOperation": LatentApplyOperation,
|
"LatentApplyOperation": LatentApplyOperation,
|
||||||
|
|||||||
@ -89,6 +89,7 @@ class DiffSynthCnetPatch:
|
|||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
|
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
|
||||||
|
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||||
|
|
||||||
def encode_latent_cond(self, image):
|
def encode_latent_cond(self, image):
|
||||||
latent_image = self.vae.encode(image)
|
latent_image = self.vae.encode(image)
|
||||||
@ -106,14 +107,15 @@ class DiffSynthCnetPatch:
|
|||||||
x = kwargs.get("x")
|
x = kwargs.get("x")
|
||||||
img = kwargs.get("img")
|
img = kwargs.get("img")
|
||||||
block_index = kwargs.get("block_index")
|
block_index = kwargs.get("block_index")
|
||||||
if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
spacial_compression = self.vae.spacial_compression_encode()
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
|
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
|
||||||
|
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
|
img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
|
||||||
kwargs['img'] = img
|
kwargs['img'] = img
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|||||||
@ -877,6 +877,68 @@ def get_audio_embed_bucket_fps(audio_embed, fps=16, batch_frames=81, m=0, video_
|
|||||||
return batch_audio_eb, min_batch_num
|
return batch_audio_eb, min_batch_num
|
||||||
|
|
||||||
|
|
||||||
|
def wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=0, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None, ref_motion_latent=None):
|
||||||
|
latent_t = ((length - 1) // 4) + 1
|
||||||
|
if audio_encoder_output is not None:
|
||||||
|
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
|
||||||
|
video_rate = 30
|
||||||
|
fps = 16
|
||||||
|
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
|
||||||
|
batch_frames = latent_t * 4
|
||||||
|
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=batch_frames, m=0, video_rate=video_rate)
|
||||||
|
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
|
||||||
|
if len(audio_embed_bucket.shape) == 3:
|
||||||
|
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
|
||||||
|
elif len(audio_embed_bucket.shape) == 4:
|
||||||
|
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
audio_embed_bucket = audio_embed_bucket[:, :, :, frame_offset:frame_offset + batch_frames]
|
||||||
|
if audio_embed_bucket.shape[3] > 0:
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})
|
||||||
|
frame_offset += batch_frames
|
||||||
|
|
||||||
|
if ref_image is not None:
|
||||||
|
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
|
||||||
|
if ref_motion is not None:
|
||||||
|
if ref_motion.shape[0] > 73:
|
||||||
|
ref_motion = ref_motion[-73:]
|
||||||
|
|
||||||
|
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
|
||||||
|
if ref_motion.shape[0] < 73:
|
||||||
|
r = torch.ones([73, height, width, 3]) * 0.5
|
||||||
|
r[-ref_motion.shape[0]:] = ref_motion
|
||||||
|
ref_motion = r
|
||||||
|
|
||||||
|
ref_motion_latent = vae.encode(ref_motion[:, :, :, :3])
|
||||||
|
|
||||||
|
if ref_motion_latent is not None:
|
||||||
|
ref_motion_latent = ref_motion_latent[:, :, -19:]
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion_latent})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion_latent})
|
||||||
|
|
||||||
|
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
|
||||||
|
if control_video is not None:
|
||||||
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
control_video = vae.encode(control_video[:, :, :, :3])
|
||||||
|
control_video_out[:, :, :control_video.shape[2]] = control_video
|
||||||
|
|
||||||
|
# TODO: check if zero is better than none if none provided
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return positive, negative, out_latent, frame_offset
|
||||||
|
|
||||||
|
|
||||||
class WanSoundImageToVideo(io.ComfyNode):
|
class WanSoundImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -906,57 +968,44 @@ class WanSoundImageToVideo(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None, control_video=None, ref_motion=None) -> io.NodeOutput:
|
||||||
latent_t = ((length - 1) // 4) + 1
|
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
|
||||||
if audio_encoder_output is not None:
|
control_video=control_video, ref_motion=ref_motion)
|
||||||
feat = torch.cat(audio_encoder_output["encoded_audio_all_layers"])
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
video_rate = 30
|
|
||||||
fps = 16
|
|
||||||
feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
|
|
||||||
audio_embed_bucket, num_repeat = get_audio_embed_bucket_fps(feat, fps=fps, batch_frames=latent_t * 4, m=0, video_rate=video_rate)
|
|
||||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
|
|
||||||
if len(audio_embed_bucket.shape) == 3:
|
|
||||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
|
|
||||||
elif len(audio_embed_bucket.shape) == 4:
|
|
||||||
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
|
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
|
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket})
|
|
||||||
|
|
||||||
if ref_image is not None:
|
class WanSoundImageToVideoExtend(io.ComfyNode):
|
||||||
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
@classmethod
|
||||||
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
def define_schema(cls):
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
return io.Schema(
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
node_id="WanSoundImageToVideoExtend",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("length", default=77, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Latent.Input("video_latent"),
|
||||||
|
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||||
|
io.Image.Input("ref_image", optional=True),
|
||||||
|
io.Image.Input("control_video", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
if ref_motion is not None:
|
@classmethod
|
||||||
if ref_motion.shape[0] > 73:
|
def execute(cls, positive, negative, vae, length, video_latent, ref_image=None, audio_encoder_output=None, control_video=None) -> io.NodeOutput:
|
||||||
ref_motion = ref_motion[-73:]
|
video_latent = video_latent["samples"]
|
||||||
|
width = video_latent.shape[-1] * 8
|
||||||
ref_motion = comfy.utils.common_upscale(ref_motion.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
height = video_latent.shape[-2] * 8
|
||||||
|
batch_size = video_latent.shape[0]
|
||||||
if ref_motion.shape[0] < 73:
|
frame_offset = video_latent.shape[-3] * 4
|
||||||
r = torch.ones([73, height, width, 3]) * 0.5
|
positive, negative, out_latent, frame_offset = wan_sound_to_video(positive, negative, vae, width, height, length, batch_size, frame_offset=frame_offset, ref_image=ref_image, audio_encoder_output=audio_encoder_output,
|
||||||
r[-ref_motion.shape[0]:] = ref_motion
|
control_video=control_video, ref_motion=None, ref_motion_latent=video_latent)
|
||||||
ref_motion = r
|
|
||||||
|
|
||||||
ref_motion = vae.encode(ref_motion[:, :, :, :3])
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"reference_motion": ref_motion})
|
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"reference_motion": ref_motion})
|
|
||||||
|
|
||||||
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
|
||||||
|
|
||||||
control_video_out = comfy.latent_formats.Wan21().process_out(torch.zeros_like(latent))
|
|
||||||
if control_video is not None:
|
|
||||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
|
||||||
control_video = vae.encode(control_video[:, :, :, :3])
|
|
||||||
control_video_out[:, :, :control_video.shape[2]] = control_video
|
|
||||||
|
|
||||||
# TODO: check if zero is better than none if none provided
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"control_video": control_video_out})
|
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"control_video": control_video_out})
|
|
||||||
|
|
||||||
out_latent = {}
|
|
||||||
out_latent["samples"] = latent
|
|
||||||
return io.NodeOutput(positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
@ -1019,6 +1068,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanCameraImageToVideo,
|
WanCameraImageToVideo,
|
||||||
WanPhantomSubjectToVideo,
|
WanPhantomSubjectToVideo,
|
||||||
WanSoundImageToVideo,
|
WanSoundImageToVideo,
|
||||||
|
WanSoundImageToVideoExtend,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.52"
|
__version__ = "0.3.54"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.52"
|
version = "0.3.54"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user