mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 05:22:34 +08:00
Update model.py
This commit is contained in:
parent
24d1b6b88a
commit
7378bf6a27
@ -367,10 +367,11 @@ class HunyuanVideo(nn.Module):
|
|||||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||||
|
|
||||||
#todo vision_in
|
#todo vision_in
|
||||||
if self.cond_type_embedding is not None and vision_states is not None:
|
if vision_states is not None:
|
||||||
txt_vision_states = self.vision_in(vision_states)
|
txt_vision_states = self.vision_in(vision_states)
|
||||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
if self.cond_type_embedding is not None:
|
||||||
txt_vision_states = txt_vision_states + cond_emb
|
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||||
|
txt_vision_states = txt_vision_states + cond_emb
|
||||||
#print("txt_vision_states shape:", txt_vision_states.shape)
|
#print("txt_vision_states shape:", txt_vision_states.shape)
|
||||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user