diff --git a/comfy/ldm/hunyuan_foley/syncformer.py b/comfy/ldm/hunyuan_foley/syncformer.py
index 089f80c92..8340bfedf 100644
--- a/comfy/ldm/hunyuan_foley/syncformer.py
+++ b/comfy/ldm/hunyuan_foley/syncformer.py
@@ -65,39 +65,32 @@ class PatchEmbed3D(nn.Module):
         x = x.flatten(2).transpose(1, 2)
         return x
 
-def qkv_attn(q, k, v, heads):
-    bh, seq_q, dim_head = q.shape
-    b = bh // heads
-
-    # (b*heads, seq, dim) -> (b, heads, seq, dim)
-    q2 = q.view(b, heads, seq_q, dim_head)
-    k2 = k.view(b, heads, k.shape[1], dim_head)
-    v2 = v.view(b, heads, v.shape[1], dim_head)
-
-    out = optimized_attention(q2, k2, v2, heads=heads, skip_reshape=True)
-
-    out = out.permute(0, 2, 1, 3).contiguous().view(b * heads, seq_q, dim_head)
-
+def qkv_attn(q, k, v):
+    sim = torch.einsum("b i d, b j d -> b i j", q, k)
+    attn = sim.softmax(dim=-1)
+    out = torch.einsum("b i j, b j d -> b i d", attn, v)
     return out
 
-
 class DividedAttention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=None):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=nn, **kwargs):
         super().__init__()
         self.num_heads = num_heads
         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
         self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
 
     def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
         h = self.num_heads
         q, k, v = self.qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        q *= self.scale
 
         (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
 
-        cls_out = qkv_attn(cls_q, k, v, self.num_heads)
+        cls_out = qkv_attn(cls_q, k, v)
 
         q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
 
@@ -107,7 +100,7 @@ class DividedAttention(nn.Module):
         k_ = torch.cat((cls_k, k_), dim=1)
         v_ = torch.cat((cls_v, v_), dim=1)
 
-        out = qkv_attn(q_, k_, v_, self.num_heads)
+        out = qkv_attn(q_, k_, v_)
 
         out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
         out = torch.cat((cls_out, out), dim=1)
diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index 8e29a05ae..a26c1524d 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -195,7 +195,7 @@ class FoleyVae(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.dac = DAC()
-        self.syncformer = Synchformer(None, None, operations = ops)
+        self.synchformer = Synchformer(None, None, operations = ops)
         self.syncformer_preprocess = v2.Compose(
             [
                 v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
@@ -208,9 +208,12 @@ class FoleyVae(torch.nn.Module):
     def decode(self, x, vae_options = {}):
         return self.dac.decode(x)
 
     def encode(self, x):
-        return self.syncformer(x)
+        return self.synchformer(x)
+
+    def forward(self, x):
+        return self.encode(x)
 
-    def video_encoding(self, video, step: int):
+    def video_encoding(self, video, step):
         t, h, w, c = video.shape
 
         if not isinstance(video, torch.Tensor):
@@ -218,10 +221,12 @@ class FoleyVae(torch.nn.Module):
 
         video = video.permute(0, 3, 1, 2)
 
-        video = torch.stack([self.syncformer_preprocess(t) for t in video]).unsqueeze(0)
+        video = torch.stack([self.syncformer_preprocess(t) for t in video])
 
         seg_len = 16
         t = video.size(0)
         nseg = max(0, (t - seg_len) // step + 1)
+
+        video = video.contiguous()
         stride_t, stride_c, stride_h, stride_w = video.stride()  # no copies
@@ -229,6 +234,7 @@ class FoleyVae(torch.nn.Module):
             size=(nseg, seg_len, c, h, w),
             stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
         )
+        data = data.unsqueeze(0)  # b
 
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
-        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
+        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=1)
diff --git a/comfy/sd.py b/comfy/sd.py
index ec44d4288..28fc45e41 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -504,14 +504,15 @@ class VAE:
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
         # Hunyuan Foley
-        elif "syncformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
+        elif "synchformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
             self.latent_dim = 128
             self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
             # TODO
             encode_layers = 25
             decode_layers = 4
-            self.memory_used_encode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * encode_layers
-            self.memory_used_decode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * decode_layers
+            self.not_video = True
+            self.memory_used_encode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * encode_layers
+            self.memory_used_decode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * decode_layers
         elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
             self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 45277c9c4..cae7e7352 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -64,7 +64,6 @@ class EncodeVideo(io.ComfyNode):
         vae = vae if vae is not None else clip_vision
 
         # should be the offload device
-        video = video.cpu()
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -82,7 +81,10 @@ class EncodeVideo(io.ComfyNode):
         for i in range(0, total, batch_size):
            chunk = data[i : i + batch_size].to(device, non_blocking = True)
            if hasattr(vae, "encode"):
-                out = vae.encode(chunk)
+                try:
+                    out = vae.encode(chunk)
+                except:
+                    out = model(chunk.to(next(model.parameters()).device))
            else:
                out = vae.encode_image(chunk)
                out = out["image_embeds"]