mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
fixes to make the model work
This commit is contained in:
parent
220c65dc5f
commit
4c782e3395
@ -154,6 +154,7 @@ class ChannelLastConv1d(nn.Module):
|
||||
object.__setattr__(self, "_underlying", underlying)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the wrapped 1-D conv to a channel-last tensor.

    The input arrives channel-last; the wrapped module expects
    channel-first, so we permute in, apply, and permute back.
    """
    # Keep the wrapped module in the activations' dtype before calling it.
    self._underlying = self._underlying.to(x.dtype)
    channel_first = x.permute(0, 2, 1)
    out = self._underlying(channel_first)
    return out.permute(0, 2, 1)
|
||||
|
||||
@ -219,6 +220,7 @@ class FinalLayer1D(nn.Module):
|
||||
def forward(self, x, c):
    """Modulate x with adaLN parameters derived from c, then project.

    `c` is mapped to a (shift, scale) pair, the normalized `x` is
    modulated with it, and the result goes through the final linear layer.
    """
    mod_params = self.adaLN_modulation(c)
    mod_shift, mod_scale = mod_params.chunk(2, dim=-1)
    x = modulate(self.norm_final(x), shift=mod_shift, scale=mod_scale)
    # Keep the projection weights in the activation dtype before applying.
    self.linear = self.linear.to(x.dtype)
    return self.linear(x)
|
||||
|
||||
@ -614,6 +616,7 @@ class SingleStreamBlock(nn.Module):
|
||||
modulation = self.modulation(cond)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
|
||||
x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
|
||||
x_norm1 = x_norm1.to(next(self.linear_qkv.parameters()).dtype)
|
||||
|
||||
qkv = self.linear_qkv(x_norm1)
|
||||
q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
|
||||
@ -647,7 +650,11 @@ def find_period_by_first_row(mat):
|
||||
if not candidate_positions:
|
||||
return L
|
||||
|
||||
return len(mat[:candidate_positions[0]])
|
||||
for p in sorted(candidate_positions):
|
||||
a, b = mat[p:], mat[:-p]
|
||||
if torch.equal(a, b):
|
||||
return p
|
||||
return L
|
||||
|
||||
def trim_repeats(expanded):
|
||||
seq = expanded[0]
|
||||
@ -865,7 +872,14 @@ class HunyuanVideoFoley(nn.Module):
|
||||
|
||||
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
|
||||
|
||||
diff = cond_pos.shape[1] - cond_neg.shape[1]
|
||||
if cond_neg.shape[1] < cond_pos.shape[1]:
|
||||
cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
|
||||
elif diff < 0:
|
||||
cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
|
||||
|
||||
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
|
||||
clip_feat = clip_feat.view(2, -1, 768)
|
||||
|
||||
if drop_visual is not None:
|
||||
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
|
||||
@ -954,8 +968,7 @@ class HunyuanVideoFoley(nn.Module):
|
||||
audio = self.final_layer(audio, vec)
|
||||
audio = self.unpatchify1d(audio, tl)
|
||||
|
||||
uncond, cond = torch.chunk(2, audio)
|
||||
return torch.cat([cond, uncond])
|
||||
return audio
|
||||
|
||||
def unpatchify1d(self, x, l):
|
||||
c = self.unpatchify_channels
|
||||
|
||||
@ -183,8 +183,10 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
|
||||
|
||||
if freq_scaling != 1.0:
|
||||
freqs *= freq_scaling
|
||||
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
if not isinstance(pos, torch.Tensor):
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
else:
|
||||
t = pos.to(freqs.device)
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
|
||||
@ -361,4 +361,4 @@ class ClapTextEncoderModel(sd1_clip.SDClipModel):
|
||||
class ClapLargeTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the CLAP-Large text encoder, backed by the bundled
    `clap_tokenizer` files next to this module."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
        # Single super().__init__ call only: the block previously invoked it
        # twice (first with max_length=99999999, then with max_length=77);
        # the first call's effects were entirely overwritten by the second,
        # so only the final, effective call (max_length=77) is kept.
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=77, min_length=1, pad_token=1, tokenizer_data=tokenizer_data)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user