diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index 78168b476..5f85c99f4 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -154,6 +154,7 @@ class ChannelLastConv1d(nn.Module):
         object.__setattr__(self, "_underlying", underlying)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        object.__setattr__(self, "_underlying", self._underlying.to(x.dtype))
         x = self._underlying(x.permute(0, 2, 1))
         return x.permute(0, 2, 1)

@@ -219,6 +220,7 @@ class FinalLayer1D(nn.Module):
     def forward(self, x, c):
         shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
         x = modulate(self.norm_final(x), shift=shift, scale=scale)
+        self.linear = self.linear.to(x.dtype)
         x = self.linear(x)
         return x

@@ -614,6 +616,7 @@ class SingleStreamBlock(nn.Module):
         modulation = self.modulation(cond)
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
         x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
+        x_norm1 = x_norm1.to(next(self.linear_qkv.parameters()).dtype)
        qkv = self.linear_qkv(x_norm1)
         q, k, v = self.rearrange(qkv).chunk(3, dim=-1)

@@ -647,7 +650,11 @@ def find_period_by_first_row(mat):
     if not candidate_positions:
         return L

-    return len(mat[:candidate_positions[0]])
+    for p in sorted(candidate_positions):
+        a, b = mat[p:], mat[:-p]
+        if torch.equal(a, b):
+            return p
+    return L

 def trim_repeats(expanded):
     seq = expanded[0]

@@ -865,7 +872,14 @@ class HunyuanVideoFoley(nn.Module):

         uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

+        diff = cond_pos.shape[1] - cond_neg.shape[1]
+        if diff > 0:
+            cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
+        elif diff < 0:
+            cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, -diff))
+
         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
+        clip_feat = clip_feat.view(2, -1, 768)
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)

@@ -954,8 +968,7 @@ class HunyuanVideoFoley(nn.Module):
         audio = self.final_layer(audio, vec)
         audio = self.unpatchify1d(audio, tl)

-        uncond, cond = torch.chunk(2, audio)
-        return torch.cat([cond, uncond])
+        return audio

     def unpatchify1d(self, x, l):
         c = self.unpatchify_channels
diff --git a/comfy/ldm/hydit/posemb_layers.py b/comfy/ldm/hydit/posemb_layers.py
index 0c2085405..893744b24 100644
--- a/comfy/ldm/hydit/posemb_layers.py
+++ b/comfy/ldm/hydit/posemb_layers.py
@@ -183,8 +183,10 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
     if freq_scaling != 1.0:
         freqs *= freq_scaling
-
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    if not isinstance(pos, torch.Tensor):
+        t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    else:
+        t = pos.to(freqs.device)
     freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
     if use_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
diff --git a/comfy/text_encoders/clap_model.py b/comfy/text_encoders/clap_model.py
index e992a52b7..c7c57a071 100644
--- a/comfy/text_encoders/clap_model.py
+++ b/comfy/text_encoders/clap_model.py
@@ -361,4 +361,4 @@ class ClapTextEncoderModel(sd1_clip.SDClipModel):
 class ClapLargeTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=1, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=77, min_length=1, pad_token=1, tokenizer_data=tokenizer_data)
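A quick standalone sanity check for the two behavioral changes in the hunyuan_foley hunks above: the pad-to-match logic for the conditioning tensors, and the period test in find_period_by_first_row. This is a sketch against plain torch; the toy shapes and values are made up for illustration and are not taken from the model.

import torch
import torch.nn.functional as F

# Pad the shorter conditioning tensor along dim=1 (toy shapes, hypothetical values).
cond_neg = torch.zeros(1, 5, 768)
cond_pos = torch.zeros(1, 8, 768)
diff = cond_pos.shape[1] - cond_neg.shape[1]
if diff > 0:
    cond_neg = F.pad(cond_neg, (0, 0, 0, diff))  # (0, 0) pads dim -1, (0, diff) pads dim -2
elif diff < 0:
    cond_pos = F.pad(cond_pos, (0, 0, 0, -diff))
assert cond_neg.shape == cond_pos.shape == (1, 8, 768)

# The period test: p is a period of mat iff shifting by p leaves the
# sequence unchanged, i.e. mat[p:] == mat[:-p].
mat = torch.tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])
assert torch.equal(mat[3:], mat[:-3])      # period 3 holds
assert not torch.equal(mat[2:], mat[:-2])  # a spurious candidate is rejected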