Update model.py

This commit is contained in:
kijai 2025-11-16 01:00:39 +02:00 committed by comfyanonymous
parent 24d1b6b88a
commit 7378bf6a27

View File

@ -367,10 +367,11 @@ class HunyuanVideo(nn.Module):
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
#todo vision_in
if self.cond_type_embedding is not None and vision_states is not None:
if vision_states is not None:
txt_vision_states = self.vision_in(vision_states)
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
#print("txt_vision_states shape:", txt_vision_states.shape)
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)