Allow HuMo to work with embedded image for I2V

Author: ozbayb
Date: 2025-09-26 07:53:08 -06:00
Parent: ce4cb2389c
Commit: 710254affc

3 changed files with 10 additions and 4 deletions


@@ -1510,7 +1510,7 @@ class HumoWanModel(WanModel):
                  operations=None,
                  ):
-        super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=36, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
         self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
@@ -1539,6 +1539,12 @@ class HumoWanModel(WanModel):
         e0 = self.time_projection(e).unflatten(2, (6, self.dim))

         if reference_latent is not None:
+            if reference_latent.shape[1] < 36:
+                padding_needed = 36 - reference_latent.shape[1]
+                padding = torch.zeros(reference_latent.shape[0], padding_needed, *reference_latent.shape[2:],
+                                      device=reference_latent.device, dtype=reference_latent.dtype)
+                reference_latent = torch.cat([padding, reference_latent], dim=1)  # pad at beginning like c_concat
+
             ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
             ref = ref.flatten(2).transpose(1, 2)
             freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
@@ -1548,7 +1554,7 @@ class HumoWanModel(WanModel):
         # context
         context = self.text_embedding(context)

-        context_img_len = None
+        context_img_len = 0

         if audio_embed is not None:
             if reference_latent is not None:

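Note on the second hunk above: the zero padding is what lets a plain reference latent feed the i2v patch embedding, whose input now expects 36 channels (in_dim=36). A minimal standalone sketch of that step, assuming a 16-channel Wan latent; the channel count and spatial shape below are illustrative, not taken from the diff:

    import torch

    # Hypothetical reference latent: (batch, channels, frames, height, width).
    reference_latent = torch.randn(1, 16, 1, 60, 104)

    if reference_latent.shape[1] < 36:
        padding_needed = 36 - reference_latent.shape[1]  # 20 zero channels here
        padding = torch.zeros(reference_latent.shape[0], padding_needed, *reference_latent.shape[2:],
                              device=reference_latent.device, dtype=reference_latent.dtype)
        # Zeros are prepended so the real latent occupies the trailing channels,
        # mirroring the layout used for c_concat.
        reference_latent = torch.cat([padding, reference_latent], dim=1)

    print(reference_latent.shape)  # torch.Size([1, 36, 1, 60, 104])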

@@ -1227,7 +1227,7 @@ class WAN21_HuMo(WAN21):
         if audio_embed is not None:
             out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)

-        if "c_concat" not in out: # 1.7B model
+        if "c_concat" not in out or "concat_latent_image" in kwargs: # 1.7B model OR I2V mode
             reference_latents = kwargs.get("reference_latents", None)
             if reference_latents is not None:
                 out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))

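The widened condition above attaches the reference latent not only for the 1.7B model (which has no c_concat) but also whenever an embedded start image is supplied via concat_latent_image. A sketch of just that predicate; should_attach_reference_latent is a hypothetical helper name, not part of the codebase:

    def should_attach_reference_latent(out: dict, kwargs: dict) -> bool:
        # New gating from the diff: missing c_concat (1.7B) OR I2V mode.
        return "c_concat" not in out or "concat_latent_image" in kwargs

    assert should_attach_reference_latent({}, {})  # 1.7B model: no c_concat
    assert should_attach_reference_latent({"c_concat": 1}, {"concat_latent_image": 1})  # I2V mode
    assert not should_attach_reference_latent({"c_concat": 1}, {})  # previously the only gated case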

@@ -1080,7 +1080,7 @@ class WAN21_HuMo(WAN21_T2V):
     }

     def get_model(self, state_dict, prefix="", device=None):
-        out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
+        out = model_base.WAN21_HuMo(self, image_to_video=True, device=device)
         return out

 class WAN22_S2V(WAN21_T2V):