Fixed the synchformer logic + condition-related fixes

The trimming function still needs an update because it can over-trim.
Yousef Rafat 2025-10-06 23:38:47 +03:00
parent 95d2aae264
commit 220c65dc5f
4 changed files with 32 additions and 30 deletions
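The hunks below change how the repeat period of a tiled tensor is detected. A minimal sketch of the idea, pieced together from the visible lines (the construction of `candidate_positions` is implied but not fully shown in the diff, so treat that part as an assumption):

```python
import torch

def find_period_by_first_row(mat):
    # candidate periods: offsets at which the first row re-appears
    L, _ = mat.shape
    first = mat[0]
    candidate_positions = [p for p in range(1, L) if torch.equal(mat[p], first)]
    if not candidate_positions:
        return L
    # after this commit: trust the first recurrence instead of verifying tiling
    return candidate_positions[0]

mat = torch.tensor([[1., 2.], [3., 4.]]).repeat(3, 1)  # rows repeat with period 2
print(find_period_by_first_row(mat))  # 2
```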


@@ -635,9 +635,6 @@ class SingleStreamBlock(nn.Module):
return x
def _ceil_div(a, b):
return (a + b - 1) // b
def find_period_by_first_row(mat):
L, _ = mat.shape
@@ -650,21 +647,7 @@ def find_period_by_first_row(mat):
if not candidate_positions:
return L
for p in sorted(candidate_positions):
base = mat[:p]
reps = _ceil_div(L, p)
tiled = base.repeat(reps, 1)[:L]
if torch.equal(tiled, mat):
return p
for p in range(1, L + 1):
base = mat[:p]
reps = _ceil_div(L, p)
tiled = base.repeat(reps, 1)[:L]
if torch.equal(tiled, mat):
return p
return L
return len(mat[:candidate_positions[0]])
def trim_repeats(expanded):
seq = expanded[0]
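The over-trimming the commit message warns about is visible in the hunk above: the new `return len(mat[:candidate_positions[0]])` (equivalent to `candidate_positions[0]`) skips the old `torch.equal(tiled, mat)` verification, so a coincidental first-row match yields a too-small period. A minimal illustration with hypothetical data:

```python
import torch

mat = torch.tensor([[0., 0.],
                    [1., 1.],
                    [0., 0.],   # first row recurs here by coincidence
                    [2., 2.]])
# the first recurrence is at index 2, so the period is reported as 2,
# yet mat[:2] tiled twice does not reproduce mat: trimming to 2 rows
# would silently drop the [2., 2.] row
print(torch.equal(mat[:2].repeat(2, 1), mat))  # False
```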
@@ -675,6 +658,14 @@ def trim_repeats(expanded):
return expanded[:, :p_len, :p_dim]
def unlock_cpu_tensor(t, device=None):
if isinstance(t, torch.Tensor):
base = t.as_subclass(torch.Tensor).detach().clone()
if device is not None:
base = base.to(device)
return base
return t
class HunyuanVideoFoley(nn.Module):
def __init__(
self,
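A short usage sketch for the `unlock_cpu_tensor` helper added above: it strips any `Tensor` subclass (e.g. `nn.Parameter` or an offloading wrapper), detaches and copies the data, and optionally moves it to a device. The helper is restated below so the example runs standalone:

```python
import torch

def unlock_cpu_tensor(t, device=None):
    # drop the subclass wrapper, detach from autograd, copy, optionally move
    if isinstance(t, torch.Tensor):
        base = t.as_subclass(torch.Tensor).detach().clone()
        if device is not None:
            base = base.to(device)
        return base
    return t  # non-tensors pass through unchanged

t = torch.nn.Parameter(torch.randn(2, 3))
plain = unlock_cpu_tensor(t, torch.device("cpu"))
print(type(plain) is torch.Tensor, plain.requires_grad)  # True False
```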
@@ -860,7 +851,7 @@ class HunyuanVideoFoley(nn.Module):
bs, _, ol = x.shape
tl = ol // self.patch_size
condition, uncondition = torch.chunk(context, 2)
uncondition, condition = torch.chunk(context, 2)
condition = condition.view(3, context.size(1) // 3, -1)
uncondition = uncondition.view(3, context.size(1) // 3, -1)
@@ -872,7 +863,7 @@ class HunyuanVideoFoley(nn.Module):
uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
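The swapped `torch.chunk` line above fixes the unpacking order: the later `torch.cat([cond_neg, cond_pos])` suggests the packed context puts the unconditional half first, so chunking must unpack in the same order. A sketch of that convention with hypothetical shapes:

```python
import torch

uncond = torch.zeros(2, 4)           # negative / unconditional half
cond = torch.ones(2, 4)              # positive / conditional half
context = torch.cat([uncond, cond])  # packed with uncond first
u, c = torch.chunk(context, 2)       # must unpack as (uncond, cond)
print(torch.equal(u, uncond), torch.equal(c, cond))  # True True
```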


@@ -211,17 +211,16 @@ class FoleyVae(torch.nn.Module):
return self.synchformer(x)
def forward(self, x):
return self.encode(x)
try:
return self.encode(x)
except:
x = x.to(next(self.parameters()).device)
return self.encode(x)
def video_encoding(self, video, step):
t, h, w, c = video.shape
if not isinstance(video, torch.Tensor):
video = torch.from_numpy(video)
video = video.permute(0, 3, 1, 2)
video = torch.stack([self.syncformer_preprocess(t) for t in video])
t, c, h, w = video.shape
seg_len = 16
t = video.size(0)
nseg = max(0, (t - seg_len) // step + 1)
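A worked example of the segment count computed above: `nseg` is the number of `seg_len`-frame windows that fit when the window start advances by `step` frames, with `max(0, ...)` guarding clips shorter than one window:

```python
import torch

t, seg_len, step = 40, 16, 8
nseg = max(0, (t - seg_len) // step + 1)
print(nseg)  # 4 windows, starting at frames 0, 8, 16, 24

video = torch.randn(t, 3, 8, 8)  # toy (t, c, h, w) clip
segments = [video[i * step : i * step + seg_len] for i in range(nseg)]
print(len(segments), segments[0].shape)  # 4 torch.Size([16, 3, 8, 8])
```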


@@ -1166,6 +1166,7 @@ class MultiheadAttentionComfyv(nn.Module):
def forward(self, src, attn_mask = None, key_padding_mask = None):
self._q_proj, self._k_proj, self._v_proj = [t.to(src.device).to(src.dtype) for t in (self._q_proj, self._k_proj, self._v_proj)]
q = self._q_proj(src)
k = self._k_proj(src)
v = self._v_proj(src)
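The added line above casts the q/k/v projections to the input's device and dtype at call time rather than at load time. A minimal sketch of that lazy-cast pattern, using `bfloat16` for a CPU-safe example:

```python
import torch
import torch.nn as nn

proj = nn.Linear(8, 8)                    # may still be fp32 on CPU after load
src = torch.randn(2, 8, dtype=torch.bfloat16)
proj = proj.to(src.device).to(src.dtype)  # follow the incoming tensor
print(proj(src).dtype)                    # torch.bfloat16
```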


@@ -51,6 +51,15 @@ class EncodeVideo(io.ComfyNode):
@classmethod
def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
if not isinstance(video, torch.Tensor):
video = torch.from_numpy(video)
t, *rest = video.shape
# channel last
if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
video = video.permute(0, 3, 1, 2)
t, c, h, w = video.shape
device = video.device
b = 1
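The layout check added above treats the tensor as channel-last (`t, h, w, c`) when the last dimension looks like a channel count and the first spatial dimension does not; the heuristic is ambiguous for frames whose height is itself 1, 3, or 4, which is presumably why both conditions are tested. A small sketch:

```python
import torch

video = torch.randn(10, 64, 64, 3)  # t, h, w, c
t, *rest = video.shape
if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
    video = video.permute(0, 3, 1, 2)  # -> t, c, h, w
print(video.shape)  # torch.Size([10, 3, 64, 64])
```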
@@ -77,14 +86,16 @@ class EncodeVideo(io.ComfyNode):
outputs = None
total = data.shape[0]
pbar = comfy.utils.ProgressBar(total/batch_size)
with torch.inference_mode():
model_dtype = next(model.parameters()).dtype
with torch.inference_mode():
for i in range(0, total, batch_size):
chunk = data[i : i + batch_size].to(device, non_blocking = True)
chunk = chunk.to(model_dtype)
if hasattr(vae, "encode"):
try:
out = vae.encode(chunk)
except:
out = model(chunk.to(next(model.parameters()).device))
out = model(chunk)
else:
out = vae.encode_image(chunk)
out = out["image_embeds"]
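Condensed, the loop this last hunk converges on casts each chunk to the model's parameter dtype inside `inference_mode` before encoding. A standalone sketch with a hypothetical helper name:

```python
import torch

def encode_in_batches(model, data, batch_size, device):
    # cast each chunk to the model's own dtype, then encode
    model_dtype = next(model.parameters()).dtype
    outputs = []
    with torch.inference_mode():
        for i in range(0, data.shape[0], batch_size):
            chunk = data[i : i + batch_size].to(device, non_blocking=True)
            outputs.append(model(chunk.to(model_dtype)))
    return torch.cat(outputs)
```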