mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00

fixed the synchformer logic + condition-related fixes

The trimming fn still needs an update because of the over-trimming.

This commit is contained in:
parent 95d2aae264
commit 220c65dc5f
@@ -635,9 +635,6 @@ class SingleStreamBlock(nn.Module):

         return x

-def _ceil_div(a, b):
-    return (a + b - 1) // b
-
 def find_period_by_first_row(mat):

     L, _ = mat.shape
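_ceil_div only served the tiling verification that the next hunk removes. For reference, a quick sanity check of the ceiling-division identity it implemented:

import math

# (a + b - 1) // b is ceiling division for positive ints,
# equivalent to math.ceil(a / b) without float rounding
for a, b in [(10, 4), (12, 4), (1, 3)]:
    assert (a + b - 1) // b == math.ceil(a / b)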
@@ -650,21 +647,7 @@ def find_period_by_first_row(mat):
     if not candidate_positions:
         return L

-    for p in sorted(candidate_positions):
-        base = mat[:p]
-        reps = _ceil_div(L, p)
-        tiled = base.repeat(reps, 1)[:L]
-        if torch.equal(tiled, mat):
-            return p
-
-    for p in range(1, L + 1):
-        base = mat[:p]
-        reps = _ceil_div(L, p)
-        tiled = base.repeat(reps, 1)[:L]
-        if torch.equal(tiled, mat):
-            return p
-
-    return L
+    return len(mat[:candidate_positions[0]])

 def trim_repeats(expanded):
     seq = expanded[0]
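The deleted loops verified a candidate period p by tiling the first p rows to the full length and comparing against mat; the replacement trusts the first candidate outright (note that len(mat[:candidate_positions[0]]) reduces to candidate_positions[0]), which is presumably the over-trimming the commit message flags. A minimal standalone sketch of the dropped verification, assuming mat is a 2-D tensor:

import torch

def _is_period(mat, p):
    # tile the first p rows to the full length and compare
    L = mat.size(0)
    reps = -(-L // p)  # ceiling division
    return torch.equal(mat[:p].repeat(reps, 1)[:L], mat)

mat = torch.tensor([[1, 2], [3, 4]]).repeat(3, 1)  # rows ABABAB, true period 2
assert _is_period(mat, 2) and not _is_period(mat, 3)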
@@ -675,6 +658,14 @@ def trim_repeats(expanded):

     return expanded[:, :p_len, :p_dim]

+def unlock_cpu_tensor(t, device=None):
+    if isinstance(t, torch.Tensor):
+        base = t.as_subclass(torch.Tensor).detach().clone()
+        if device is not None:
+            base = base.to(device)
+        return base
+    return t
+
 class HunyuanVideoFoley(nn.Module):
     def __init__(
         self,
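unlock_cpu_tensor strips a tensor subclass back to a plain torch.Tensor via as_subclass before moving it, so a wrapper that pins data to CPU (for example by overriding .to()) cannot block the device transfer. A sketch with a hypothetical locked subclass, assuming the helper above is in scope:

import torch

class CPULockedTensor(torch.Tensor):
    # hypothetical stand-in for a wrapper whose .to() refuses to move
    def to(self, *args, **kwargs):
        return self

t = torch.ones(2).as_subclass(CPULockedTensor)
plain = unlock_cpu_tensor(t)
assert type(plain) is torch.Tensor  # subclass stripped; .to(device) now works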
@@ -860,7 +851,7 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size

-        condition, uncondition = torch.chunk(context, 2)
+        uncondition, condition = torch.chunk(context, 2)

         condition = condition.view(3, context.size(1) // 3, -1)
         uncondition = uncondition.view(3, context.size(1) // 3, -1)
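torch.chunk splits along dim 0, so the unpacking order has to match how the context batch was packed; the fix assumes the unconditional half comes first, consistent with torch.cat([cond_neg, cond_pos]) in the next hunk. A quick check with illustrative shapes:

import torch

context = torch.cat([torch.zeros(2, 4), torch.ones(2, 4)])  # [uncond; cond]
uncondition, condition = torch.chunk(context, 2)
assert uncondition.eq(0).all() and condition.eq(1).all()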
@@ -872,7 +863,7 @@ class HunyuanVideoFoley(nn.Module):
         uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
         uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]

-        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
+        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
@@ -211,17 +211,16 @@ class FoleyVae(torch.nn.Module):
         return self.synchformer(x)

     def forward(self, x):
-        return self.encode(x)
+        try:
+            return self.encode(x)
+        except:
+            x = x.to(next(self.parameters()).device)
+            return self.encode(x)

     def video_encoding(self, video, step):
-        t, h, w, c = video.shape
-
-        if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video)
-
-        video = video.permute(0, 3, 1, 2)
-
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
+
+        t, c, h, w = video.shape
         seg_len = 16
         t = video.size(0)
         nseg = max(0, (t - seg_len) // step + 1)
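The retained segmenting logic slides a 16-frame window forward by step frames; nseg = max(0, (t - seg_len) // step + 1) counts the full windows that fit. A worked example:

# 40 frames, 16-frame windows, stride 8 -> windows start at 0, 8, 16, 24
t, seg_len, step = 40, 16, 8
nseg = max(0, (t - seg_len) // step + 1)
assert nseg == 4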
@@ -1166,6 +1166,7 @@ class MultiheadAttentionComfyv(nn.Module):

     def forward(self, src, attn_mask = None, key_padding_mask = None):

+        self._q_proj, self._k_proj, self._v_proj = [t.to(src.device).to(src.dtype) for t in (self._q_proj, self._k_proj, self._v_proj)]
         q = self._q_proj(src)
         k = self._k_proj(src)
         v = self._v_proj(src)
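Re-aligning the projection layers with the input's device and dtype on each call guards against partially offloaded or downcast checkpoints; Module.to is a no-op once everything already matches. A minimal sketch of the dtype half (bfloat16 chosen arbitrarily):

import torch
from torch import nn

proj = nn.Linear(8, 8)
src = torch.randn(2, 8, dtype=torch.bfloat16)
proj = proj.to(src.device).to(src.dtype)  # no-op when already aligned
assert proj(src).dtype == torch.bfloat16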
@@ -51,6 +51,15 @@ class EncodeVideo(io.ComfyNode):
     @classmethod
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):

+        if not isinstance(video, torch.Tensor):
+            video = torch.from_numpy(video)
+
+        t, *rest = video.shape
+
+        # channel last
+        if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
+            video = video.permute(0, 3, 1, 2)
+
         t, c, h, w = video.shape
         device = video.device
         b = 1
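The permute heuristic treats the trailing axis as channels only when it has a plausible channel count (1, 3, or 4) and the leading spatial axis does not, so channel-last input is converted while already channel-first input is left untouched. A standalone check of the same logic:

import torch

def to_channel_first(video):
    # mirrors the heuristic above
    t, *rest = video.shape
    if rest[-1] in (1, 3, 4) and rest[0] not in (1, 3, 4):
        video = video.permute(0, 3, 1, 2)
    return video

assert to_channel_first(torch.zeros(8, 64, 64, 3)).shape == (8, 3, 64, 64)
assert to_channel_first(torch.zeros(8, 3, 64, 64)).shape == (8, 3, 64, 64)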
@@ -77,14 +86,16 @@ class EncodeVideo(io.ComfyNode):
         outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
-        with torch.inference_mode():
+        model_dtype = next(model.parameters()).dtype
+        with torch.inference_mode():
             for i in range(0, total, batch_size):
                 chunk = data[i : i + batch_size].to(device, non_blocking = True)
+                chunk = chunk.to(model_dtype)
                 if hasattr(vae, "encode"):
                     try:
                         out = vae.encode(chunk)
                     except:
-                        out = model(chunk.to(next(model.parameters()).device))
+                        out = model(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
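Looking up next(model.parameters()).dtype once outside the loop and casting each chunk up front replaces the per-failure device juggling in the old except branch. A minimal sketch of the chunked-batch pattern, with a hypothetical encoder stand-in:

import torch

model = torch.nn.Linear(4, 2).to(torch.bfloat16)  # hypothetical encoder stand-in
data = torch.randn(10, 4)
batch_size = 4
model_dtype = next(model.parameters()).dtype      # queried once

outputs = []
with torch.inference_mode():
    for i in range(0, data.shape[0], batch_size):
        chunk = data[i : i + batch_size].to(model_dtype)
        outputs.append(model(chunk))
out = torch.cat(outputs)
assert out.shape == (10, 2) and out.dtype == torch.bfloat16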