syncformer fix + some fixes
Commit 95d2aae264 (parent 4b6c08110d)
@@ -65,39 +65,32 @@ class PatchEmbed3D(nn.Module):
        x = x.flatten(2).transpose(1, 2)
        return x

-def qkv_attn(q, k, v, heads):
-    bh, seq_q, dim_head = q.shape
-    b = bh // heads
-
-    # (b*heads, seq, dim) -> (b, heads, seq, dim)
-    q2 = q.view(b, heads, seq_q, dim_head)
-    k2 = k.view(b, heads, k.shape[1], dim_head)
-    v2 = v.view(b, heads, v.shape[1], dim_head)
-
-    out = optimized_attention(q2, k2, v2, heads=heads, skip_reshape=True)
-
-    out = out.permute(0, 2, 1, 3).contiguous().view(b * heads, seq_q, dim_head)
-
+def qkv_attn(q, k, v):
+    sim = torch.einsum("b i d, b j d -> b i j", q, k)
+    attn = sim.softmax(dim=-1)
+    out = torch.einsum("b i j, b j d -> b i d", attn, v)
+    return out


class DividedAttention(nn.Module):

-    def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=None):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=nn, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
        self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
        h = self.num_heads

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q *= self.scale

        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))

-        cls_out = qkv_attn(cls_q, k, v, self.num_heads)
+        cls_out = qkv_attn(cls_q, k, v)

        q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
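For reference, the einsum qkv_attn added here computes plain softmax attention with no internal scaling; DividedAttention.forward applies self.scale to q before calling it. A minimal sketch (not part of the commit, assuming PyTorch >= 2.0 for F.scaled_dot_product_attention) checking that the two conventions line up:

import torch
import torch.nn.functional as F

def qkv_attn(q, k, v):  # mirrors the version added above
    sim = torch.einsum("b i d, b j d -> b i j", q, k)
    attn = sim.softmax(dim=-1)
    return torch.einsum("b i j, b j d -> b i d", attn, v)

bh, n, d = 6, 10, 32                        # (batch*heads, tokens, head_dim)
q, k, v = (torch.randn(bh, n, d) for _ in range(3))
scale = d ** -0.5

ref = F.scaled_dot_product_attention(q, k, v)   # softmax(q k^T / sqrt(d)) v
out = qkv_attn(q * scale, k, v)                 # q pre-scaled, as in DividedAttention.forward
assert torch.allclose(out, ref, atol=1e-4)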
@@ -107,7 +100,7 @@ class DividedAttention(nn.Module):
        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

-        out = qkv_attn(q_, k_, v_, self.num_heads)
+        out = qkv_attn(q_, k_, v_)
        out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)

        out = torch.cat((cls_out, out), dim=1)
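The einops_from / einops_to patterns and einops_dims come from the caller and are not shown in this hunk. Purely as an illustration (the exact strings below are assumed, in the spirit of TimeSformer-style divided space-time attention), the rearrange calls regroup tokens so attention runs either within a frame or along time:

import torch
from einops import rearrange

b, f, n, d = 2, 4, 9, 16                 # batch, frames, spatial tokens, head_dim
x = torch.randn(b, f * n, d)             # tokens flattened over time and space

# spatial attention: fold frames into the batch, attend over the n tokens of each frame
xs = rearrange(x, "b (f n) d -> (b f) n d", f=f)   # (8, 9, 16)
# temporal attention: fold spatial positions into the batch, attend over the f frames
xt = rearrange(x, "b (f n) d -> (b n) f d", n=n)   # (18, 4, 16)
assert xt.shape == (b * n, f, d)

# the reverse pattern restores the original layout, as in the hunk above
assert rearrange(xs, "(b f) n d -> b (f n) d", f=f).shape == x.shape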
@@ -195,7 +195,7 @@ class FoleyVae(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dac = DAC()
-        self.syncformer = Synchformer(None, None, operations = ops)
+        self.synchformer = Synchformer(None, None, operations = ops)
        self.syncformer_preprocess = v2.Compose(
            [
                v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
@@ -208,9 +208,12 @@ class FoleyVae(torch.nn.Module):
    def decode(self, x, vae_options = {}):
        return self.dac.decode(x)
    def encode(self, x):
-        return self.syncformer(x)
+        return self.synchformer(x)

+    def forward(self, x):
+        return self.encode(x)

-    def video_encoding(self, video, step: int):
+    def video_encoding(self, video, step):
        t, h, w, c = video.shape

        if not isinstance(video, torch.Tensor):
@@ -218,10 +221,12 @@ class FoleyVae(torch.nn.Module):

        video = video.permute(0, 3, 1, 2)

-        video = torch.stack([self.syncformer_preprocess(t) for t in video]).unsqueeze(0)
+        video = torch.stack([self.syncformer_preprocess(t) for t in video])
+        seg_len = 16
+        t = video.size(0)
        nseg = max(0, (t - seg_len) // step + 1)

        video = video.contiguous()
        stride_t, stride_c, stride_h, stride_w = video.stride()

        # no copies
@@ -229,6 +234,7 @@ class FoleyVae(torch.nn.Module):
            size=(nseg, seg_len, c, h, w),
            stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
        )
        data = data.unsqueeze(0) # b
        data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")

-        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
+        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=1)
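video_encoding relies on torch.as_strided to slice the preprocessed frames into overlapping seg_len-frame segments that share storage with the original tensor. A self-contained sketch of that windowing (the sizes here are made up for illustration):

import torch

t, c, h, w = 40, 3, 8, 8
seg_len, step = 16, 8
video = torch.arange(t * c * h * w, dtype=torch.float32).reshape(t, c, h, w).contiguous()

nseg = max(0, (t - seg_len) // step + 1)             # 4 segments for these sizes
stride_t, stride_c, stride_h, stride_w = video.stride()

# each segment starts `step` frames after the previous one; storage is shared, nothing is copied
segments = torch.as_strided(
    video,
    size=(nseg, seg_len, c, h, w),
    stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
)

for i in range(nseg):
    assert torch.equal(segments[i], video[i * step : i * step + seg_len])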
@@ -504,14 +504,15 @@ class VAE:
            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]

        # Hunyuan Foley
-        elif "syncformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
+        elif "synchformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
            self.latent_dim = 128
            self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
            # TODO
            encode_layers = 25
            decode_layers = 4
-            self.memory_used_encode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * encode_layers
-            self.memory_used_decode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * decode_layers
            self.not_video = True
+            self.memory_used_encode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * encode_layers
+            self.memory_used_decode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * decode_layers

        elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
            self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
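The memory_used_* lambdas estimate bytes as element count times bytes per element times a per-layer factor; math.prod works directly on a shape tuple, while torch.prod expects a tensor, which is presumably the reason for the switch. A rough sketch, with a stand-in for model_management.dtype_size and an illustrative shape:

import math
import torch

def dtype_size(dtype):
    # bytes per element; stand-in for comfy.model_management.dtype_size
    return torch.tensor([], dtype=dtype).element_size()

encode_layers = 25
memory_used_encode = lambda shape, dtype: math.prod(shape) * dtype_size(dtype) * encode_layers

shape = (1, 128, 345)                                 # illustrative latent shape only
print(memory_used_encode(shape, torch.float16))       # 44160 elements * 2 bytes * 25 = 2208000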
@@ -64,7 +64,6 @@ class EncodeVideo(io.ComfyNode):
        vae = vae if vae is not None else clip_vision

        # should be the offload device
        video = video.cpu()
        if hasattr(model, "video_encoding"):
            data, num_segments, output_fn = model.video_encoding(video, step_size)
            batch_size = b * num_segments
@@ -82,7 +81,10 @@ class EncodeVideo(io.ComfyNode):
        for i in range(0, total, batch_size):
            chunk = data[i : i + batch_size].to(device, non_blocking = True)
            if hasattr(vae, "encode"):
-                out = vae.encode(chunk)
+                try:
+                    out = vae.encode(chunk)
+                except:
+                    out = model(chunk.to(next(model.parameters()).device))
            else:
                out = vae.encode_image(chunk)
                out = out["image_embeds"]
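The hunk wraps vae.encode in a try/except so that a failing encode falls back to calling the underlying module directly. A sketch of the same chunked loop as a standalone helper (the names vae, model and data mirror the node code above rather than a public API; the broad except is narrowed to Exception here):

import torch

def encode_in_chunks(data, vae, model, device, batch_size):
    outputs = []
    for i in range(0, data.shape[0], batch_size):
        chunk = data[i : i + batch_size].to(device, non_blocking=True)
        if hasattr(vae, "encode"):
            try:
                out = vae.encode(chunk)
            except Exception:
                # fall back to running the raw module on whatever device its weights live on
                out = model(chunk.to(next(model.parameters()).device))
        else:
            out = vae.encode_image(chunk)["image_embeds"]
        outputs.append(out)
    return torch.cat(outputs, dim=0)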