From 220c65dc5f2162bac185e2998e95ba24a6bc362f Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Mon, 6 Oct 2025 23:38:47 +0300
Subject: [PATCH] fixed the syncform logic + condition-related fixes

the trimming fn needs an update because of the over-trimming
---
 comfy/ldm/hunyuan_foley/model.py | 31 +++++++++++--------------------
 comfy/ldm/hunyuan_foley/vae.py   | 15 +++++++--------
 comfy/ldm/modules/attention.py   |  1 +
 comfy_extras/nodes_video.py      | 15 +++++++++++++--
 4 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index bce9b25fc..78168b476 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -635,9 +635,6 @@ class SingleStreamBlock(nn.Module):
         return x


-def _ceil_div(a, b):
-    return (a + b - 1) // b
-
 def find_period_by_first_row(mat):
     L, _ = mat.shape

@@ -650,21 +647,7 @@ def find_period_by_first_row(mat):
     if not candidate_positions:
         return L

-    for p in sorted(candidate_positions):
-        base = mat[:p]
-        reps = _ceil_div(L, p)
-        tiled = base.repeat(reps, 1)[:L]
-        if torch.equal(tiled, mat):
-            return p
-
-    for p in range(1, L + 1):
-        base = mat[:p]
-        reps = _ceil_div(L, p)
-        tiled = base.repeat(reps, 1)[:L]
-        if torch.equal(tiled, mat):
-            return p
-
-    return L
+    return len(mat[:candidate_positions[0]])

 def trim_repeats(expanded):
     seq = expanded[0]
@@ -675,6 +658,14 @@ def trim_repeats(expanded):

     return expanded[:, :p_len, :p_dim]

+def unlock_cpu_tensor(t, device=None):
+    if isinstance(t, torch.Tensor):
+        base = t.as_subclass(torch.Tensor).detach().clone()
+        if device is not None:
+            base = base.to(device)
+        return base
+    return t
+
 class HunyuanVideoFoley(nn.Module):
     def __init__(
         self,
@@ -860,7 +851,7 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size

-        condition, uncondition = torch.chunk(context, 2)
+        uncondition, condition = torch.chunk(context, 2)

         condition = condition.view(3, context.size(1) // 3, -1)
         uncondition = uncondition.view(3, context.size(1) // 3, -1)
@@ -872,7 +863,7 @@ class HunyuanVideoFoley(nn.Module):
         uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
         uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]

-        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
+        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])

diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index a26c1524d..58c9bfdb4 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -211,17 +211,16 @@ class FoleyVae(torch.nn.Module):
         return self.synchformer(x)

     def forward(self, x):
-        return self.encode(x)
+        try:
+            return self.encode(x)
+        except:
+            x = x.to(next(self.parameters()).device)
+            return self.encode(x)

     def video_encoding(self, video, step):
-        t, h, w, c = video.shape
-
-        if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video)
-
-        video = video.permute(0, 3, 1, 2)
-        video = torch.stack([self.syncformer_preprocess(t) for t in video])
+
+        t, c, h, w = video.shape
         seg_len = 16
         t = video.size(0)
         nseg = max(0, (t - seg_len) // step + 1)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index d8a2be67b..eed49269b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -1166,6 +1166,7 @@ class MultiheadAttentionComfyv(nn.Module):

     def forward(self, src, attn_mask = None, key_padding_mask = None):

+        self._q_proj, self._k_proj, self._v_proj = [t.to(src.device).to(src.dtype) for t in (self._q_proj, self._k_proj, self._v_proj)]
         q = self._q_proj(src)
         k = self._k_proj(src)
         v = self._v_proj(src)
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index cae7e7352..56a1457a1 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -51,6 +51,15 @@ class EncodeVideo(io.ComfyNode):

     @classmethod
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
+        if not isinstance(video, torch.Tensor):
+            video = torch.from_numpy(video)
+
+        t, *rest = video.shape
+
+        # channel last
+        if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
+            video = video.permute(0, 3, 1, 2)
+
         t, c, h, w = video.shape
         device = video.device
         b = 1
@@ -77,14 +86,16 @@ class EncodeVideo(io.ComfyNode):
         outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
-        with torch.inference_mode():
+        model_dtype = next(model.parameters()).dtype
+        with torch.inference_mode():
             for i in range(0, total, batch_size):
                 chunk = data[i : i + batch_size].to(device, non_blocking = True)
+                chunk = chunk.to(model_dtype)
                 if hasattr(vae, "encode"):
                     try:
                         out = vae.encode(chunk)
                     except:
-                        out = model(chunk)
+                        out = model(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
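
Note on the trimming caveat in the commit message: the simplified find_period_by_first_row() now returns the first candidate position without the tiling verification that the removed loops performed, so it can report a period that is too small, which is what leads to the over-trimming mentioned above. A minimal sketch of that failure mode, assuming candidate_positions holds the row indices that match row 0 (that part of the function is not shown in this diff; first_row_candidates below is a hypothetical stand-in):

import torch

def first_row_candidates(mat):
    # hypothetical stand-in for the unshown candidate_positions computation:
    # positions p > 0 whose row equals row 0
    return [p for p in range(1, mat.shape[0]) if torch.equal(mat[p], mat[0])]

mat = torch.tensor([
    [1., 0.],
    [0., 1.],
    [1., 0.],   # row 0 recurs here, so p = 2 is a candidate...
    [0., 2.],   # ...but the sequence does not actually repeat with period 2
])

p = first_row_candidates(mat)[0]    # 2, what the new code would return
tiled = mat[:p].repeat(2, 1)        # tile the assumed period over the full length
print(torch.equal(tiled, mat))      # False: the removed verification loop would
                                    # have rejected p = 2 and kept searching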
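
The new unlock_cpu_tensor() helper replaces the previous t.to(device, allow_gpu=True) call on the conditioning tensors: it strips any torch.Tensor subclass back to a plain tensor, detaches and copies it, and optionally moves it to the target device. A small usage sketch (PinnedTensor is a hypothetical subclass standing in for whatever wrapper the conditioning tensors arrive as; unlock_cpu_tensor is copied from the patch so the example is self-contained):

import torch

def unlock_cpu_tensor(t, device=None):
    # copied from the patch above
    if isinstance(t, torch.Tensor):
        base = t.as_subclass(torch.Tensor).detach().clone()
        if device is not None:
            base = base.to(device)
        return base
    return t

class PinnedTensor(torch.Tensor):
    # hypothetical wrapper subclass, only used to illustrate the stripping
    pass

t = torch.ones(2, 3).as_subclass(PinnedTensor).requires_grad_(True)
plain = unlock_cpu_tensor(t, device="cpu")

print(type(plain))          # <class 'torch.Tensor'>: subclass removed
print(plain.requires_grad)  # False: detached copy, safe to move and concatenate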
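
The numpy conversion and channel reordering that used to live in FoleyVae.video_encoding() now happen in EncodeVideo.execute(), which guesses the layout from the shape: if the last dimension looks like a channel count and the second does not, the frames are treated as THWC and permuted to TCHW. A quick sketch of that heuristic together with the segment count computed in video_encoding() (the frame count and step here are made up for illustration; seg_len = 16 comes from the patch):

import torch

video = torch.rand(64, 224, 224, 3)          # 64 THWC frames, e.g. fresh from numpy
t, *rest = video.shape
if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
    video = video.permute(0, 3, 1, 2)        # -> (64, 3, 224, 224), TCHW

# sliding-window segment count computed in video_encoding():
seg_len, step = 16, 8
nseg = max(0, (video.size(0) - seg_len) // step + 1)
print(video.shape, nseg)                     # torch.Size([64, 3, 224, 224]) 7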