Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-11 05:52:33 +08:00)

Commit 12824eac0d ("init"), parent fee1e57ea9

comfy/clap_model.py (new file, 356 lines)
@@ -0,0 +1,356 @@
from typing import Optional, Callable

import torch
from torch import nn

from comfy.text_encoders.bert import (
    BertIntermediate as ClapTextIntermediate,
    BertOutput as ClapTextOutput,
)
from comfy.ldm.modules.attention import optimized_attention

import os
from comfy import sd1_clip
from transformers import AutoTokenizer


def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:

    if chunk_size > 0:
        # chunk into tuples and apply forward_fn to each element in the tuple
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)

        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        return torch.cat(output_chunks, dim=chunk_dim)

    return forward_fn(*input_tensors)
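
# Illustrative sketch (not part of the diff): with chunk_size > 0 the inputs are split
# along chunk_dim, forward_fn is applied per chunk, and the results are concatenated;
# with chunk_size == 0 it is a plain call, so both paths agree for an elementwise fn.
#
#   def toy_ff(x):
#       return x * 2.0
#
#   x = torch.randn(2, 8, 4)                               # (batch, seq, hidden)
#   full = apply_chunking_to_forward(toy_ff, 0, 1, x)      # no chunking
#   chunked = apply_chunking_to_forward(toy_ff, 2, 1, x)   # 8 // 2 = 4 chunks along dim 1
#   assert torch.allclose(full, chunked)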

# from modeling_roberta
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx
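# Worked example (illustrative): with padding_idx=1, real tokens are numbered from 2
# and padded slots keep position 1, matching RoBERTa position ids.
#   ids = torch.tensor([[5, 7, 9, 1, 1]])
#   create_position_ids_from_input_ids(ids, 1)  ->  tensor([[2, 3, 4, 1, 1]])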


class ClapProjectionLayer(nn.Module):
    def __init__(self, hidden_size, projection_dim, device, dtype, operations):
        super().__init__()
        self.linear1 = operations.Linear(hidden_size, projection_dim, device=device, dtype=dtype)
        self.activation = torch.nn.ReLU()
        self.linear2 = operations.Linear(projection_dim, projection_dim, device=device, dtype=dtype)

    def forward(self, hidden_states):
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.linear2(hidden_states)
        return hidden_states


# same as RobertaEmbeddings
class ClapTextEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.word_embeddings = operations.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, device=device, dtype=dtype)
        self.position_embeddings = operations.Embedding(max_position_embeddings, hidden_size, device=device, dtype=dtype)
        self.token_type_embeddings = operations.Embedding(type_vocab_size, hidden_size, device=device, dtype=dtype)

        self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings, device=device, dtype=dtype).expand((1, -1)), persistent=True
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=device), persistent=True
        )

        # End copy
        self.padding_idx = pad_token_id
        self.position_embeddings = operations.Embedding(
            max_position_embeddings, hidden_size, padding_idx=self.padding_idx, device=device, dtype=dtype
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):

        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)


# same as AlignTextSelfAttention
class ClapTextSelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, device, dtype, operations):
        super().__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
        self.key = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
        self.value = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)

        self.scaling = self.attention_head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states, key_states, value_states = [t.contiguous() for t in (query_states, key_states, value_states)]
        attn_output = optimized_attention(query_states, key_states, value_states, self.num_attention_heads, mask=attention_mask, skip_output_reshape=True, skip_reshape=True)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output.reshape(*input_shape, -1).contiguous()


class ClapTextSelfOutput(nn.Module):
    def __init__(self, hidden_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
        self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# same as AlignTextAttention
class ClapTextAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        self.self = ClapTextSelfAttention(num_attention_heads, hidden_size, device=device, dtype=dtype, operations=operations)
        self.output = ClapTextSelfOutput(hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        return self.output(self_outputs, hidden_states)


# same as AlignTextLayer
class ClapTextLayer(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device, dtype, operations):
        super().__init__()
        # maybe we could allow chunking dynamically
        self.chunk_size_feed_forward = 0
        self.seq_len_dim = 1
        self.attention = ClapTextAttention(num_attention_heads, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
        self.intermediate = ClapTextIntermediate(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
        self.output = ClapTextOutput(intermediate_size, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        attention_output = self_attention_outputs

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return layer_output

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


# same as AlignTextEncoder
class ClapTextEncoder(nn.Module):
    def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device, dtype, operations):
        super().__init__()
        self.layer = nn.ModuleList([ClapTextLayer(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
                                    for _ in range(num_hidden_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):

        for _, layer_module in enumerate(self.layer):
            hidden_states = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                **kwargs,
            )
        return hidden_states


class ClapTextPooler(nn.Module):
    def __init__(self, hidden_size, device, dtype, operations):
        super().__init__()
        self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ClapTextModel(nn.Module):

    def __init__(self, num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, num_hidden_layers,
                 device, dtype, operations, add_pooling_layer=True):
        super().__init__()

        self.embeddings = ClapTextEmbeddings(vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps,
                                             device=device, dtype=dtype, operations=operations)
        self.encoder = ClapTextEncoder(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
        self.pooler = ClapTextPooler(hidden_size, device=device, dtype=dtype, operations=operations) if add_pooling_layer else None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):

        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
        )
        sequence_output = encoder_outputs
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return sequence_output, pooled_output


class ClapTextModelWithProjection(nn.Module):
    def __init__(
        self,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        layer_norm_eps: float = 1e-12,
        max_position_embeddings: int = 514,
        num_attention_heads: int = 12,
        num_hidden_layers: int = 12,
        projection_dim: int = 512,
        type_vocab_size: int = 1,
        vocab_size: int = 50265,
        pad_token_id: int = 1,
        device=None,
        dtype=None,
        operations=None
    ):
        super().__init__()
        self.text_model = ClapTextModel(num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings,
                                        type_vocab_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
        self.text_projection = ClapProjectionLayer(hidden_size, projection_dim, device=device, dtype=dtype, operations=operations)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        return text_embeds, text_outputs[0]


class ClapTextEncoderModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 1}, layer_norm_hidden_state=False, model_class=ClapTextModelWithProjection, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class ClapLargeTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
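
# Minimal usage sketch (assumption: plain torch.nn works as the `operations` namespace
# here, since only Linear / Embedding / LayerNorm are taken from it):
#   model = ClapTextModelWithProjection(device="cpu", dtype=torch.float32, operations=torch.nn)
#   ids = torch.full((1, 8), 5, dtype=torch.long)
#   text_embeds, sequence_output = model(input_ids=ids)
#   text_embeds.shape      ->  (1, 512)      # projection_dim
#   sequence_output.shape  ->  (1, 8, 768)   # hidden_size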
comfy/latent_formats.py
@@ -551,3 +551,7 @@ class Hunyuan3Dv2mini(LatentFormat):

class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2

class HunyuanFoley(LatentFormat):
    latent_dimensions = 128
    latent_channels = 1024
comfy/ldm/hunyuan_foley/model.py (new file, 921 lines)
@@ -0,0 +1,921 @@
|
||||
from typing import List, Tuple, Optional, Union, Dict
|
||||
from functools import partial
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention as attention
|
||||
from comfy.ldm.aura.mmdit import TimestepEmbedder as TimestepEmbedderParent
|
||||
from comfy.ldm.hydit.posemb_layers import get_1d_rotary_pos_embed
|
||||
|
||||
from typing import Union, Tuple
|
||||
|
||||
# to get exact matching results
|
||||
# only difference is the upscale to float32
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
|
||||
device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
return embedding
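# Shape check (illustrative): standard sinusoidal timestep embedding, cosine half
# followed by sine half along the last dimension.
#   timestep_embedding(torch.tensor([0.0, 1.0]), 256).shape  ->  torch.Size([2, 256])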
|
||||
|
||||
class TimestepEmbedder(TimestepEmbedderParent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
|
||||
self.w2 = operations.Linear(hidden_dim, hidden_dim, bias=False, device=device, dtype=dtype)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
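# i.e. a gated MLP without biases: out = w2(silu(w1(x)) * w3(x)).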
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
|
||||
ndim = x.ndim
|
||||
if head_first:
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
head_first: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
class ConditionProjection(nn.Module):
|
||||
def __init__(self, in_channels, hidden_size, dtype=None, device=None, operations = None):
|
||||
factory_kwargs = {'dtype': dtype, 'device': device}
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
|
||||
self.act_1 = nn.SiLU()
|
||||
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
|
||||
|
||||
def forward(self, caption):
|
||||
return self.linear_2(self.act_1(self.linear_1(caption)))
|
||||
|
||||
class PatchEmbed1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=1,
|
||||
in_chans=768,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations = None
|
||||
):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.flatten = flatten
|
||||
|
||||
self.proj = operations.Conv1d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
# avoid classifying as wrapper to work with operations.conv1d
|
||||
class ChannelLastConv1d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, bias=True, kernel_size = 3, padding = 0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
operations = operations or nn
|
||||
underlying = operations.Conv1d(
|
||||
in_channels, out_channels, kernel_size = kernel_size, padding = padding,
|
||||
bias=bias, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
self.register_parameter("weight", underlying.weight)
|
||||
if getattr(underlying, "bias", None) is not None:
|
||||
self.register_parameter("bias", underlying.bias)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
object.__setattr__(self, "_underlying", underlying)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self._underlying(x.permute(0, 2, 1))
|
||||
return x.permute(0, 2, 1)
|
||||
|
||||
class ConvMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int = 256,
|
||||
kernel_size: int = 3,
|
||||
padding: int = 1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations = None
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
|
||||
self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
|
||||
self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
if x.ndim == 3:
|
||||
shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
|
||||
scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale)
|
||||
elif scale is None:
|
||||
return x + shift
|
||||
else:
|
||||
return x * (1 + scale) + shift
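# Worked example (illustrative): for 3-D x and 2-D shift/scale the modulation is
# broadcast over the sequence dimension.
#   x = torch.ones(1, 4, 8); shift = torch.zeros(1, 8); scale = torch.full((1, 8), 0.5)
#   modulate(x, shift=shift, scale=scale)  ->  every element equals 1.5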
|
||||
|
||||
class ModulateDiT(nn.Module):
|
||||
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations = None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.act = nn.SiLU()
|
||||
self.linear = operations.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.act(x))
|
||||
|
||||
class FinalLayer1D(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, device=None, dtype=None, operations = None):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_channels=None,
|
||||
out_features=None,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
bias = (bias, bias)
|
||||
drop_probs = (drop, drop)
|
||||
linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
|
||||
self.act = nn.GELU(approximate="tanh")
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x))))))
|
||||
|
||||
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
|
||||
return grid
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
|
||||
):
|
||||
|
||||
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
|
||||
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
freq_scaling=freq_scaling,
|
||||
)
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1)
|
||||
return emb
|
||||
|
||||
def apply_gate(x, gate = None):
|
||||
if gate is None:
|
||||
return x
|
||||
if gate.ndim == 2 and x.ndim == 3:
|
||||
gate = gate.unsqueeze(1)
|
||||
return x * gate
|
||||
|
||||
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
|
||||
B, N1, H, C = x1.shape
|
||||
B, N2, H, C = x2.shape
|
||||
assert x1.ndim == x2.ndim == 4
|
||||
|
||||
if N1 != N2:
|
||||
x2 = x2.view(B, N2, -1).transpose(1, 2)
|
||||
x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
|
||||
x2 = x2.transpose(1, 2).view(B, N1, H, C)
|
||||
x = torch.stack((x1, x2), dim=2)
|
||||
x = x.reshape(B, N1 * 2, H, C)
|
||||
return x
|
||||
|
||||
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
|
||||
B, N, H, C = x.shape
|
||||
assert N % 2 == 0 and N // 2 == len1
|
||||
|
||||
x = x.reshape(B, -1, 2, H, C)
|
||||
x1 = x[:, :, 0]
|
||||
x2 = x[:, :, 1]
|
||||
if x2.shape[1] != len2:
|
||||
x2 = x2.view(B, len1, H * C).transpose(1, 2)
|
||||
x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
|
||||
x2 = x2.transpose(1, 2).view(B, len2, H, C)
|
||||
return x1, x2
|
||||
|
||||
def apply_modulated_block(x, norm_layer, shift, scale, mlp_layer, gate):
|
||||
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
|
||||
return x + apply_gate(mlp_layer(x_mod), gate=gate)
|
||||
|
||||
def prepare_self_attn_qkv(x, norm_layer, qkv_layer, q_norm, k_norm, shift, scale, num_heads):
|
||||
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
|
||||
qkv = qkv_layer(x_mod)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=num_heads)
|
||||
|
||||
q = q_norm(q).to(v)
|
||||
k = k_norm(k).to(v)
|
||||
return q, k, v
|
||||
|
||||
class TwoStreamCABlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: bool = True,
|
||||
qkv_bias: bool = False,
|
||||
interleaved_audio_visual_rope: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
operations = None
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
head_dim = hidden_size // num_heads
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
|
||||
|
||||
self.audio_mod = ModulateDiT(hidden_size, factor=9, operations = operations, **factory_kwargs)
|
||||
self.audio_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.audio_self_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
def make_qk_norm(name: str):
|
||||
layer = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
setattr(self, name, layer)
|
||||
|
||||
for name in ["v_cond_attn_q_norm", "v_cond_attn_k_norm", "audio_cross_q_norm",
|
||||
"v_cond_cross_q_norm", "text_cross_k_norm", "audio_self_q_norm", "audio_self_k_norm"]:
|
||||
make_qk_norm(name)
|
||||
|
||||
self.audio_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
self.v_cond_mod = ModulateDiT(hidden_size, factor = 9, operations = operations, **factory_kwargs)
|
||||
self.v_cond_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.v_cond_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
self.v_cond_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
self.max_text_len = 100
|
||||
self.rope_dim_list = None
|
||||
|
||||
self.audio_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.v_cond_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.audio_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
self.v_cond_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
self.text_cross_kv = operations.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
self.audio_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
self.v_cond_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
||||
|
||||
self.audio_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.audio_mlp = MLP(
|
||||
hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs
|
||||
)
|
||||
|
||||
self.v_cond_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.v_cond_mlp = MLP(
|
||||
hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs
|
||||
)
|
||||
|
||||
def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
|
||||
target_ndim = 1 # n-d RoPE
|
||||
rope_sizes = [text_len]
|
||||
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
|
||||
text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
|
||||
rope_dim_list=rope_dim_list,
|
||||
start=rope_sizes,
|
||||
theta=10000,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1.0,
|
||||
)
|
||||
return text_freqs_cos, text_freqs_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
audio: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
v_cond: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
freqs_cis: tuple = None,
|
||||
v_freqs_cis: tuple = None,
|
||||
sync_vec: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
||||
(audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
|
||||
audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
|
||||
audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
|
||||
) = self.audio_mod(sync_vec if sync_vec is not None else vec).chunk(9, dim=-1)
|
||||
|
||||
(
|
||||
v_cond_mod1_shift,
|
||||
v_cond_mod1_scale,
|
||||
v_cond_mod1_gate,
|
||||
v_cond_mod2_shift,
|
||||
v_cond_mod2_scale,
|
||||
v_cond_mod2_gate,
|
||||
v_cond_mod3_shift,
|
||||
v_cond_mod3_scale,
|
||||
v_cond_mod3_gate,
|
||||
) = self.v_cond_mod(vec).chunk(9, dim=-1)
|
||||
|
||||
audio_q, audio_k, audio_v = prepare_self_attn_qkv(
|
||||
audio, self.audio_norm1, self.audio_self_attn_qkv,
|
||||
self.audio_self_q_norm, self.audio_self_k_norm,
|
||||
audio_mod1_shift, audio_mod1_scale, self.num_heads
|
||||
)
|
||||
|
||||
v_cond_q, v_cond_k, v_cond_v = prepare_self_attn_qkv(
|
||||
v_cond, self.v_cond_norm1, self.v_cond_attn_qkv,
|
||||
self.v_cond_attn_q_norm, self.v_cond_attn_k_norm,
|
||||
v_cond_mod1_shift, v_cond_mod1_scale, self.num_heads
|
||||
)
|
||||
|
||||
# Apply RoPE if needed for audio and visual
|
||||
if freqs_cis is not None:
|
||||
if not self.interleaved_audio_visual_rope:
|
||||
audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
|
||||
audio_q, audio_k = audio_qq, audio_kk
|
||||
else:
|
||||
ori_audio_len = audio_q.shape[1]
|
||||
ori_v_con_len = v_cond_q.shape[1]
|
||||
interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
|
||||
interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
|
||||
interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
|
||||
interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
|
||||
)
|
||||
audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
|
||||
interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
|
||||
)
|
||||
audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
|
||||
interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
|
||||
)
|
||||
audio_q, audio_k = audio_qq, audio_kk
|
||||
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
|
||||
|
||||
if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
|
||||
v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
|
||||
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
|
||||
|
||||
q = torch.cat((v_cond_q, audio_q), dim=1)
|
||||
k = torch.cat((v_cond_k, audio_k), dim=1)
|
||||
v = torch.cat((v_cond_v, audio_v), dim=1)
|
||||
|
||||
# TODO: look further into here
|
||||
if attention.__name__ == "attention_pytorch":
|
||||
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
|
||||
|
||||
attn = attention(q, k, v, heads = self.num_heads, mask=attn_mask, skip_reshape=True)
|
||||
v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
|
||||
|
||||
audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
|
||||
v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
audio_q = self.prepare_modulated_query(audio, self.audio_norm2, self.audio_cross_q,
|
||||
self.audio_cross_q_norm, audio_mod2_shift, audio_mod2_scale,
|
||||
self.num_heads, self.rope_dim_list)
|
||||
|
||||
v_cond_q = self.prepare_modulated_query(v_cond, self.v_cond_norm2, self.v_cond_cross_q,
|
||||
self.v_cond_cross_q_norm, v_cond_mod2_shift, v_cond_mod2_scale,
|
||||
self.num_heads, self.rope_dim_list)
|
||||
|
||||
text_kv = self.text_cross_kv(cond)
|
||||
text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
|
||||
text_k = self.text_cross_k_norm(text_k).to(text_v)
|
||||
|
||||
text_len = text_k.shape[1]
|
||||
|
||||
text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
|
||||
rope_dim_list=self.rope_dim_list)
|
||||
text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
|
||||
text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
|
||||
|
||||
v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
|
||||
|
||||
if attention.__name__ == "attention_pytorch":
|
||||
v_cond_audio_q, text_k, text_v = [t.transpose(1, 2) for t in (v_cond_audio_q, text_k, text_v)]
|
||||
|
||||
cross_attn = attention(v_cond_audio_q, text_k, text_v, self.num_heads, skip_reshape = True)
|
||||
v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
|
||||
|
||||
audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
|
||||
v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
|
||||
|
||||
audio = apply_modulated_block(audio, self.audio_norm3, audio_mod3_shift, audio_mod3_scale, self.audio_mlp, audio_mod3_gate)
|
||||
v_cond = apply_modulated_block(v_cond, self.v_cond_norm3, v_cond_mod3_shift, v_cond_mod3_scale, self.v_cond_mlp, v_cond_mod3_gate)
|
||||
|
||||
return audio, cond, v_cond
|
||||
|
||||
def prepare_modulated_query(self, x, norm_layer, q_layer, q_norm_layer, shift, scale, num_heads, rope_dim_list):
|
||||
|
||||
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
|
||||
q = q_layer(x_mod)
|
||||
|
||||
q = rearrange(q, "B L (H D) -> B L H D", H=num_heads)
|
||||
q = q_norm_layer(q)
|
||||
|
||||
head_dim = q.shape[-1]
|
||||
freqs_cos, freqs_sin = self.build_rope_for_text(q.shape[1], head_dim, rope_dim_list)
|
||||
freqs_cis = (freqs_cos.to(q.device), freqs_sin.to(q.device))
|
||||
|
||||
q = apply_rotary_emb(q, q, freqs_cis, head_first=False)[0]
|
||||
|
||||
return q
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
operations = None):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.modulation = ModulateDiT(
|
||||
hidden_size=hidden_size,
|
||||
factor=6,
|
||||
operations = operations,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.linear_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=True)
|
||||
self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, operations = operations, **factory_kwargs)
|
||||
self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, operations = operations, **factory_kwargs)
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
|
||||
self.q_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs)
|
||||
self.k_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs)
|
||||
self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
|
||||
|
||||
modulation = self.modulation(cond)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
|
||||
x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
|
||||
|
||||
qkv = self.linear_qkv(x_norm1)
|
||||
q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
|
||||
|
||||
q, k, v = [t.squeeze(-1) for t in (q, k, v)]
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
|
||||
|
||||
q, k, v = [t.contiguous() for t in (q, k, v)]
|
||||
|
||||
out = attention(q, k, v, self.num_heads, skip_output_reshape = True, skip_reshape = True)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
|
||||
|
||||
x = x + apply_gate(self.linear1(out),gate=gate_msa)
|
||||
x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
|
||||
x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
class HunyuanVideoFoley(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_args,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
operations = None
|
||||
):
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.depth_triple_blocks = 18
|
||||
self.depth_single_blocks = 36
|
||||
|
||||
self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", True)
|
||||
|
||||
self.condition_dim = model_args.get("condition_dim", 768)
|
||||
|
||||
self.patch_size = model_args.get("patch_size", 1)
|
||||
self.visual_in_channels = model_args.get("clip_dim", 768)
|
||||
self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
|
||||
self.out_channels = self.audio_vae_latent_dim
|
||||
self.unpatchify_channels = self.out_channels
|
||||
|
||||
self.num_heads = model_args.get("num_heads", 12)
|
||||
self.hidden_size = model_args.get("hidden_size", 1536)
|
||||
self.rope_dim_list = model_args.get("rope_dim_list", None)
|
||||
self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
|
||||
|
||||
self.qkv_bias = model_args.get("qkv_bias", True)
|
||||
self.qk_norm = model_args.get("qk_norm", True)
|
||||
|
||||
# sync condition things
|
||||
self.sync_modulation = model_args.get("sync_modulation", False)
|
||||
self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", True)
|
||||
self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
|
||||
self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
|
||||
|
||||
self.clip_len = model_args.get("clip_length", 64)
|
||||
self.sync_len = model_args.get("sync_length", 192)
|
||||
|
||||
self.patch_size = 1
|
||||
self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, operations=operations, **factory_kwargs)
|
||||
self.visual_proj = SwiGLU(dim = self.visual_in_channels, hidden_dim = self.hidden_size, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.cond_in = ConditionProjection(
|
||||
self.condition_dim, self.hidden_size, operations=operations, **factory_kwargs
|
||||
)
|
||||
|
||||
self.time_in = TimestepEmbedder(self.hidden_size, operations = operations, **factory_kwargs)
|
||||
|
||||
# visual sync embedder if needed
|
||||
if self.sync_in_ksz == 1:
|
||||
sync_in_padding = 0
|
||||
elif self.sync_in_ksz == 3:
|
||||
sync_in_padding = 1
|
||||
else:
|
||||
raise ValueError
|
||||
if self.sync_modulation or self.add_sync_feat_to_audio:
|
||||
self.sync_in = nn.Sequential(
|
||||
operations.Linear(self.sync_feat_dim, self.hidden_size, **factory_kwargs),
|
||||
nn.SiLU(),
|
||||
ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding, operations=operations, **factory_kwargs),
|
||||
)
|
||||
self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim), **factory_kwargs))
|
||||
|
||||
self.triple_blocks = nn.ModuleList(
|
||||
[
|
||||
TwoStreamCABlock(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qk_norm=self.qk_norm,
|
||||
qkv_bias=self.qkv_bias,
|
||||
interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
|
||||
operations=operations,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(self.depth_triple_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
operations=operations,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(self.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer1D(
|
||||
self.hidden_size, self.patch_size, self.out_channels, operations = operations,**factory_kwargs
|
||||
)
|
||||
self.unpatchify_channels = self.out_channels
|
||||
|
||||
self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
|
||||
self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
|
||||
|
||||
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
|
||||
len = len if len is not None else self.clip_len
|
||||
if bs is None:
|
||||
return self.empty_clip_feat.expand(len, -1) # 15s
|
||||
else:
|
||||
return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
|
||||
|
||||
def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
|
||||
len = len if len is not None else self.sync_len
|
||||
if bs is None:
|
||||
return self.empty_sync_feat.expand(len, -1)
|
||||
else:
|
||||
return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
|
||||
|
||||
def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
|
||||
target_ndim = 1 # n-d RoPE
|
||||
rope_sizes = [audio_emb_len]
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
rope_dim_list=rope_dim_list,
|
||||
start=rope_sizes,
|
||||
theta=10000,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1.0,
|
||||
)
|
||||
|
||||
target_ndim = 1
|
||||
rope_sizes = [visual_cond_len]
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
|
||||
rope_dim_list=rope_dim_list,
|
||||
start=rope_sizes,
|
||||
theta=10000,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1.0,
|
||||
freq_scaling=1.0 * audio_emb_len / visual_cond_len,
|
||||
)
|
||||
return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
|
||||
|
||||
def build_rope_for_interleaved_audio_visual(self, total_len):
|
||||
target_ndim = 1 # n-d RoPE
|
||||
rope_sizes = [total_len]
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
rope_dim_list=rope_dim_list,
|
||||
start=rope_sizes,
|
||||
theta=10000,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1.0,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
full_cond: torch.Tensor,
|
||||
transformer_options = {},
|
||||
drop_visual: Optional[List[bool]] = None,
|
||||
):
|
||||
audio = x
|
||||
bs, _, ol = x.shape
|
||||
tl = ol // self.patch_size
|
||||
|
||||
condition, uncondition = torch.chunk(full_cond, 2)
|
||||
uncond_1, uncond_2, uncond_3 = torch.chunk(uncondition, 3)
|
||||
clip_feat, sync_feat, cond = torch.chunk(condition, 3)
|
||||
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
|
||||
|
||||
if drop_visual is not None:
|
||||
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
|
||||
sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
|
||||
|
||||
vec = self.time_in(t)
|
||||
sync_vec = None
|
||||
if self.add_sync_feat_to_audio:
|
||||
sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb.to(sync_feat.device)
|
||||
sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)
|
||||
sync_feat = self.sync_in.to(sync_feat.device)(sync_feat)
|
||||
add_sync_feat_to_audio = (
|
||||
F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
|
||||
)
|
||||
|
||||
cond = self.cond_in(cond)
|
||||
cond_seq_len = cond.shape[1]
|
||||
|
||||
audio = self.audio_embedder(x)
|
||||
audio_seq_len = audio.shape[1]
|
||||
v_cond = self.visual_proj(clip_feat)
|
||||
v_cond_seq_len = v_cond.shape[1]
|
||||
attn_mask = None
|
||||
|
||||
|
||||
freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
|
||||
v_freqs_cos = v_freqs_sin = None
|
||||
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
||||
v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None
|
||||
|
||||
if self.add_sync_feat_to_audio:
|
||||
add_sync_layer = 0
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
def block_wrap(**kwargs):
|
||||
return block(**kwargs)
|
||||
|
||||
for layer_num, block in enumerate(self.triple_blocks):
|
||||
if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
|
||||
audio = audio + add_sync_feat_to_audio
|
||||
triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
|
||||
if ("triple_block", layer_num) in blocks_replace:
|
||||
audio, cond, v_cond = blocks_replace[("triple_block", layer_num)]({
|
||||
"audio": triple_block_args[0],
|
||||
"cond": triple_block_args[1],
|
||||
"v_cond": triple_block_args[2],
|
||||
"attn_mask": triple_block_args[3],
|
||||
"vec": triple_block_args[4],
|
||||
"freqs_cis": triple_block_args[5],
|
||||
"v_freqs_cis": triple_block_args[6],
|
||||
"sync_vec": triple_block_args[7]
|
||||
}, {"original_block": block_wrap})
|
||||
else:
|
||||
audio, cond, v_cond = block(*triple_block_args)
|
||||
|
||||
x = audio
|
||||
if sync_vec is not None:
|
||||
vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
|
||||
vec = torch.cat((vec, sync_vec), dim=1)
|
||||
|
||||
freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
|
||||
if self.add_sync_feat_to_audio:
|
||||
vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
|
||||
if len(self.single_blocks) > 0:
|
||||
for layer_num, block in enumerate(self.single_blocks):
|
||||
single_block_args = [
|
||||
x,
|
||||
vec,
|
||||
(freqs_cos, freqs_sin),
|
||||
]
|
||||
if ("single_block", layer_num) in blocks_replace:
|
||||
x = blocks_replace[("single_block", layer_num)]({
|
||||
"x": single_block_args[0],
|
||||
"vec": single_block_args[1],
|
||||
"freqs_cis": single_block_args[2]
|
||||
}, {"original_block": block_wrap})
|
||||
else:
|
||||
x = block(*single_block_args)
|
||||
|
||||
audio = x
|
||||
|
||||
if sync_vec is not None:
|
||||
vec = sync_vec
|
||||
audio = self.final_layer(audio, vec)
|
||||
audio = self.unpatchify1d(audio, tl)
|
||||
|
||||
uncond, cond = torch.chunk(audio, 2)
|
||||
return torch.cat([cond, uncond])
|
||||
|
||||
def unpatchify1d(self, x, l):
|
||||
c = self.unpatchify_channels
|
||||
p = self.patch_size
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], l, p, c))
|
||||
x = torch.einsum("ntpc->nctp", x)
|
||||
audio = x.reshape(shape=(x.shape[0], c, l * p))
|
||||
return audio
|
||||
comfy/ldm/hunyuan_foley/syncformer.py (new file, 991 lines)
@@ -0,0 +1,991 @@
#!/usr/bin/env python3

from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat, einops

from comfy.ldm.hunyuan3dv2_1.hunyuandit import MLP as Mlp
from transformers.modeling_outputs import BaseModelOutputWithPooling
from comfy.ldm.modules.attention import optimized_attention, TransformerEncoderComfyv
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig

from typing import Optional, Union, Tuple

class Config:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, device=None, dtype=None, operations=None):
        super().__init__()
        img_size = img_size if type(img_size) is tuple else (img_size, img_size)
        patch_size = patch_size if type(patch_size) is tuple else (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed3D(nn.Module):
    def __init__(
        self,
        img_size=224,
        in_chans=3,
        patch_size=16,
        z_block_size=2,
        embed_dim=768,
        flatten=True,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.height = img_size // patch_size
        self.width = img_size // patch_size
        self.z_block_size = z_block_size
        self.proj = operations.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=(z_block_size, patch_size, patch_size),
            stride=(z_block_size, patch_size, patch_size),
            device=device, dtype=dtype
        )
        self.flatten = flatten

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return x


def qkv_attn(q, k, v, heads):
|
||||
bh, seq_q, dim_head = q.shape
|
||||
b = bh // heads
|
||||
|
||||
# (b*heads, seq, dim) -> (b, heads, seq, dim)
|
||||
q2 = q.view(b, heads, seq_q, dim_head)
|
||||
k2 = k.view(b, heads, k.shape[1], dim_head)
|
||||
v2 = v.view(b, heads, v.shape[1], dim_head)
|
||||
|
||||
out = optimized_attention(q2, k2, v2, heads=heads, skip_reshape=True)
|
||||
|
||||
out = out.permute(0, 2, 1, 3).contiguous().view(b * heads, seq_q, dim_head)
|
||||
|
||||
return out
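# Note: qkv_attn expects q, k, v already flattened to (batch*heads, seq, head_dim); it regroups
# them to (batch, heads, seq, head_dim) for optimized_attention and flattens the result back,
# so DividedAttention can keep working in the "(b h) n d" layout used above.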
|
||||
|
||||
|
||||
class DividedAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
|
||||
h = self.num_heads
|
||||
|
||||
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
||||
|
||||
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
|
||||
|
||||
cls_out = qkv_attn(cls_q, k, v, self.num_heads)
|
||||
|
||||
q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
|
||||
|
||||
r = q_.shape[0] // cls_k.shape[0]
|
||||
cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v))
|
||||
|
||||
k_ = torch.cat((cls_k, k_), dim=1)
|
||||
v_ = torch.cat((cls_v, v_), dim=1)
|
||||
|
||||
out = qkv_attn(q_, k_, v_, self.num_heads)
|
||||
out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
|
||||
|
||||
out = torch.cat((cls_out, out), dim=1)
|
||||
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
||||
|
||||
x = self.proj(out)
|
||||
return x
|
||||
|
||||
class DividedSpaceTimeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=768,
|
||||
num_heads=12,
|
||||
qkv_bias=False,
|
||||
norm_layer=nn.LayerNorm,
|
||||
device = None, dtype = None, operations = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device":device, "dtype": dtype}
|
||||
|
||||
self.einops_from_space = "b (f n) d"
|
||||
self.einops_to_space = "(b f) n d"
|
||||
self.einops_from_time = "b (f n) d"
|
||||
self.einops_to_time = "(b n) f d"
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
||||
self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, operations = operations, **factory_kwargs)
|
||||
|
||||
self.timeattn = DividedAttention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, operations=operations, **factory_kwargs
|
||||
)
|
||||
|
||||
self.drop_path = nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(width = dim, operations = operations, device=device, dtype=dtype)
|
||||
self.norm3 = norm_layer(dim)
|
||||
|
||||
def forward(self, x, seq_len=196, num_frames=8, tok_mask: torch.Tensor = None):
|
||||
time_output = self.timeattn(
|
||||
self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask
|
||||
)
|
||||
time_residual = x + time_output
|
||||
|
||||
space_output = self.attn(
|
||||
self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask
|
||||
)
|
||||
space_residual = time_residual + self.drop_path(space_output)
|
||||
|
||||
x = space_residual
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
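# Divided space-time attention: timeattn first attends across frames at each spatial location
# ("(b n) f d"), then attn attends across spatial positions within each frame ("(b f) n d"),
# with residual connections around both attention steps and around the final MLP.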
|
||||
|
||||
class MotionFormer(nn.Module):
|
||||
def __init__(self, device = None, dtype = None, operations = None):
|
||||
super().__init__()
|
||||
self.APPROX_ATTN_TYPE = "none"
|
||||
self.APPROX_ATTN_DIM = 64
|
||||
self.img_size = 224
|
||||
self.patch_size = 16
|
||||
self.in_chans = 3
|
||||
self.num_classes = 174
|
||||
self.embed_dim = 768
|
||||
self.depth = 12
|
||||
self.num_heads = 12
|
||||
self.mlp_ratio = 4
|
||||
self.qkv_bias = True
|
||||
self.drop_rate = 0.0
|
||||
self.drop_path_rate = 0.2
|
||||
self.temporal_resolution = 8
|
||||
self.use_mlp = True
|
||||
self.num_features = self.embed_dim
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
self.attn_drop_rate = 0.0
|
||||
self.factorize_space_time = True
|
||||
|
||||
# Patch Embedding
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
# 3D Patch Embedding
|
||||
self.patch_embed_3d = PatchEmbed3D(
|
||||
img_size=self.img_size,
|
||||
patch_size=self.patch_size,
|
||||
in_chans=self.in_chans,
|
||||
embed_dim=self.embed_dim,
|
||||
z_block_size = 2,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
|
||||
|
||||
# Number of patches
|
||||
self.num_patches = self.patch_embed.num_patches * self.temporal_resolution
|
||||
|
||||
# CLS token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, device=device, dtype=dtype))
|
||||
|
||||
# Positional embedding
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim, device=device, dtype=dtype))
|
||||
self.pos_drop = nn.Dropout(p=0.0)
|
||||
|
||||
self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim, device=device, dtype=dtype))
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
DividedSpaceTimeBlock(
|
||||
dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
qkv_bias=self.qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = norm_layer(self.embed_dim)
|
||||
|
||||
self.pre_logits = nn.Identity()
|
||||
|
||||
transf_enc_layer_kwargs = dict(
|
||||
d_model=self.embed_dim,
|
||||
nhead=self.num_heads,
|
||||
activation=nn.GELU(),
|
||||
batch_first=True,
|
||||
dim_feedforward=self.mlp_ratio * self.embed_dim,
|
||||
dropout=self.drop_rate,
|
||||
layer_norm_eps=1e-6,
|
||||
norm_first=True,
|
||||
)
|
||||
self.spatial_attn_agg = SpatialTransformerEncoderLayer(device = device, dtype=dtype, operations=operations,**transf_enc_layer_kwargs)
|
||||
self.temp_attn_agg = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
|
||||
B = x.shape[0]
|
||||
|
||||
# apply patching on input
|
||||
x = self.patch_embed_3d(x)
|
||||
tok_mask = None
|
||||
|
||||
# Append CLS token
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
new_pos_embed = self.pos_embed
|
||||
npatch = self.patch_embed.num_patches
|
||||
|
||||
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
||||
tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
|
||||
tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
|
||||
total_pos_embed = tile_pos_embed + tile_temporal_embed
|
||||
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
|
||||
x = x + total_pos_embed
|
||||
|
||||
# Apply positional dropout
|
||||
x = self.pos_drop(x)
|
||||
|
||||
# Encoding using transformer layers
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(
|
||||
x,
|
||||
seq_len=npatch,
|
||||
num_frames=self.temporal_resolution,
|
||||
tok_mask=tok_mask,
|
||||
)
|
||||
|
||||
return x, tok_mask
|
||||
|
||||
def forward(self, x):
|
||||
B, S, C, T, H, W = x.shape
|
||||
|
||||
orig_shape = (B, S, C, T, H, W)
|
||||
x = x.view(B * S, C, T, H, W) # flatten batch and segments
|
||||
x = self.forward_segments(x, orig_shape=orig_shape)
|
||||
x = x.view(B, S, *x.shape[1:])
|
||||
|
||||
return x
|
||||
|
||||
def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
|
||||
x, x_mask = self.forward_features(x)
|
||||
|
||||
x = x[:, 1:, :]
|
||||
x = self.norm(x)
|
||||
x = self.pre_logits(x)
|
||||
if self.factorize_space_time:
|
||||
x = self.restore_spatio_temp_dims(x, orig_shape)
|
||||
|
||||
x = self.spatial_attn_agg(x, x_mask)
|
||||
x = self.temp_attn_agg(x)
|
||||
|
||||
return x
|
||||
|
||||
def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
|
||||
|
||||
B, S, C, T, H, W = orig_shape
|
||||
D = self.embed_dim
|
||||
|
||||
# num patches in each dimension
|
||||
t = T // self.patch_embed_3d.z_block_size
|
||||
h = self.patch_embed_3d.height
|
||||
w = self.patch_embed_3d.width
|
||||
|
||||
feats = feats.permute(0, 2, 1) # (B*S, D, T)
|
||||
feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
|
||||
|
||||
return feats
|
||||
|
||||
class BaseEncoderLayer(TransformerEncoderComfyv):
|
||||
def __init__(
|
||||
self,
|
||||
add_pos_emb: bool = False,
|
||||
pos_emb_drop: float = None,
|
||||
pos_max_len: int = None,
|
||||
device = None,
|
||||
dtype = None, operations = None,
|
||||
*args, **kwargs
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__(operations = operations, *args, **kwargs, **factory_kwargs)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim, **factory_kwargs))
|
||||
|
||||
self.add_pos_emb = add_pos_emb
|
||||
if add_pos_emb:
|
||||
self.pos_max_len = 1 + pos_max_len
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim, **factory_kwargs))
|
||||
self.pos_drop = nn.Dropout(pos_emb_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
|
||||
batch_dim = x.shape[0]
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_dim, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=-2)
|
||||
if x_mask is not None:
|
||||
cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device)
|
||||
x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)
|
||||
B, N = x_mask_w_cls.shape
|
||||
x_mask_w_cls = (
|
||||
x_mask_w_cls.reshape(B, 1, 1, N)
|
||||
.expand(-1, self.self_attn.num_heads, N, -1)
|
||||
.reshape(B * self.self_attn.num_heads, N, N)
|
||||
)
|
||||
assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool"
|
||||
x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
|
||||
else:
|
||||
x_mask_w_cls = None
|
||||
|
||||
# add positional embedding
|
||||
if self.add_pos_emb:
|
||||
seq_len = x.shape[1]
|
||||
assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})"
|
||||
x = x + self.pos_emb[:, :seq_len, :]
|
||||
x = self.pos_drop(x)
|
||||
|
||||
x = super().forward(src=x, src_mask=x_mask_w_cls)
|
||||
|
||||
x = x[:, 0, :]
|
||||
|
||||
return x
|
||||
|
||||
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
|
||||
BS, D, t, h, w = x.shape
|
||||
|
||||
x = rearrange(x, "BS D t h w -> (BS t) (h w) D")
|
||||
if x_mask is not None:
|
||||
x_mask = rearrange(x_mask, "BS t h w -> (BS t) (h w)")
|
||||
|
||||
x = super().forward(x=x, x_mask=x_mask)
|
||||
|
||||
x = rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t)
|
||||
|
||||
return x
|
||||
|
||||
class AST(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
max_spec_t: int = None,
|
||||
factorize_freq_time: bool = None,
|
||||
max_segments: int = None,
|
||||
device = None, dtype = None, operations = None
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.extract_features = True
|
||||
self.max_spec_t = max_spec_t
|
||||
self.max_segments = max_segments
|
||||
|
||||
self.config = ASTConfig()
|
||||
self.config.num_labels = 527
|
||||
|
||||
self.ast = ASTModel(self.config, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.feat_type = "last_hidden_state"
|
||||
self.factorize_freq_time = factorize_freq_time
|
||||
|
||||
transf_enc_layer_kwargs = dict(
|
||||
d_model=self.config.hidden_size,
|
||||
nhead=self.config.num_attention_heads,
|
||||
dim_feedforward=self.config.intermediate_size,
|
||||
activation=torch.nn.GELU(),
|
||||
batch_first=True,
|
||||
dropout=self.config.attention_probs_dropout_prob,
|
||||
layer_norm_eps=1e-6,
|
||||
norm_first=True,
|
||||
)
|
||||
if factorize_freq_time:
|
||||
self.feat_type = "last_hidden_state"
|
||||
self.freq_attn_agg = FrequencyTransformerEncoderLayer(operations = operations, **transf_enc_layer_kwargs, **factory_kwargs)
|
||||
self.temp_attn_agg = torch.nn.Identity()
|
||||
|
||||
self.device = device
|
||||
|
||||
self.patch_position_emb()
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
|
||||
) -> torch.Tensor:
|
||||
|
||||
B, S, T, F = x.shape
|
||||
|
||||
if for_loop:
|
||||
assert cont_mask is None, "cont_mask is not supported with for_loop=True"
|
||||
orig_shape_s = (B, 1, T, F)
|
||||
x = torch.cat(
|
||||
[self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
|
||||
)
|
||||
else:
|
||||
orig_shape = (B, S, T, F)
|
||||
x = x.view(B * S, T, F)
|
||||
if cont_mask is not None:
|
||||
cont_mask = cont_mask.reshape(B * S, T, F)
|
||||
x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
|
||||
x = x.view(B, S, *x.shape[1:])
|
||||
|
||||
global_x = None
|
||||
|
||||
return x, global_x
|
||||
|
||||
def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
|
||||
|
||||
x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)
|
||||
|
||||
if self.extract_features:
|
||||
x = self.get_features_by_type(x)
|
||||
if self.factorize_freq_time:
|
||||
x = self.restore_freq_temp_dims(x, orig_shape)
|
||||
if cont_mask is not None:
|
||||
x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
|
||||
x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)
|
||||
x_mask = x_mask[:, 0, :, :]
|
||||
else:
|
||||
x_mask = None
|
||||
x = self.freq_attn_agg(x, x_mask)
|
||||
x = self.temp_attn_agg(x)
|
||||
else:
|
||||
x = x["pooler_output"]
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def get_features_by_type(self, x) -> torch.Tensor:
|
||||
return x["last_hidden_state"] # (B, 2+T, D)
|
||||
|
||||
def restore_freq_temp_dims(self, feats, orig_shape: tuple):
|
||||
B, S, T, F = orig_shape
|
||||
D = self.config.hidden_size
|
||||
|
||||
# num patches in each dimension
|
||||
f, t = self.ast.embeddings.get_shape(self.config)
|
||||
|
||||
if self.feat_type == "last_hidden_state":
|
||||
feats = feats[:, 2:, :] # removing CLS and distill tokens
|
||||
|
||||
feats = feats.permute(0, 2, 1) # (B*S, D, T)
|
||||
feats = feats.view(B * S, D, f, t) # (B*S, D, f, t)
|
||||
|
||||
return feats
|
||||
|
||||
def patch_position_emb(self):
|
||||
if self.max_spec_t is not None:
|
||||
self.config.max_length = self.max_spec_t
|
||||
f, t = self.ast.embeddings.get_shape(self.config)
|
||||
shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens
|
||||
self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)
|
||||
|
||||
def to(self, device):
|
||||
self.device = torch.device(device)
|
||||
return super().to(device)
|
||||
|
||||
|
||||
class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
|
||||
BS, D, f, t = x.shape
|
||||
|
||||
x = x.permute(0, 3, 2, 1)
|
||||
x = x.reshape(BS * t, f, D)
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask.permute(0, 2, 1)
|
||||
x_mask = x_mask.reshape(BS * t, f)
|
||||
|
||||
x = super().forward(x=x, x_mask=x_mask)
|
||||
|
||||
x = x.view(BS, t, D)
|
||||
|
||||
return x
|
||||
|
||||
class ASTEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype))
|
||||
self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = ASTPatchEmbeddings(config, device, dtype, operations)
|
||||
|
||||
frequency_out_dimension, time_out_dimension = self.get_shape(config)
|
||||
num_patches = frequency_out_dimension * time_out_dimension
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size, device=device, dtype=dtype))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.config = config
|
||||
|
||||
def get_shape(self, config):
|
||||
frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
|
||||
time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
|
||||
|
||||
return frequency_out_dimension, time_out_dimension
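# Worked example (assuming the stock transformers ASTConfig defaults of num_mel_bins=128,
# patch_size=16, frequency_stride=10, time_stride=10, with max_length later patched to 66
# by AST.patch_position_emb):
#   frequency_out_dimension = (128 - 16) // 10 + 1 = 12
#   time_out_dimension      = (66 - 16) // 10 + 1 = 6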
|
||||
|
||||
def forward(self, input_values: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = input_values.shape[0]
|
||||
embeddings = self.patch_embeddings(input_values)
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
|
||||
embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class ASTPatchEmbeddings(nn.Module):
|
||||
def __init__(self, config, device = None, dtype = None, operations = None):
|
||||
super().__init__()
|
||||
|
||||
|
||||
patch_size = config.patch_size
|
||||
frequency_stride = config.frequency_stride
|
||||
time_stride = config.time_stride
|
||||
|
||||
self.projection = operations.Conv2d(
|
||||
1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride), device = device, dtype = dtype
|
||||
)
|
||||
|
||||
def forward(self, input_values: torch.Tensor) -> torch.Tensor:
|
||||
input_values = input_values.unsqueeze(1)
|
||||
input_values = input_values.transpose(2, 3)
|
||||
embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
|
||||
return embeddings
|
||||
|
||||
|
||||
class ASTSelfAttention(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None:
|
||||
super().__init__()
|
||||
factory_kwargs = { "device": device, "dtype": dtype }
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
|
||||
self.key = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
|
||||
self.value = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
tok_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
if tok_mask is not None:
|
||||
attn_mask = (tok_mask == 0)
|
||||
attn_mask = attn_mask[:, None, None, :]
|
||||
else:
|
||||
attn_mask = None
|
||||
context_layer = optimized_attention(query_layer, key_layer, value_layer, self.num_attention_heads, mask = attn_mask, skip_output_reshape=True, skip_reshape=True)
|
||||
context_layer = context_layer.view(*query_layer.size())
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
return (context_layer,)
|
||||
|
||||
class ASTSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(config.hidden_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class ASTAttention(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.attention = ASTSelfAttention(config, device=device, dtype=dtype, operations=operations)
|
||||
self.output = ASTSelfOutput(config, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
tok_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
self_outputs = self.attention(hidden_states, tok_mask, head_mask)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class ASTIntermediate(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
|
||||
self.intermediate_act_fn = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ASTOutput(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + input_tensor
|
||||
|
||||
return hidden_states
|
||||
|
||||
class ASTLayer(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
factory_kwargs = {"device":device, "dtype":dtype}
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ASTAttention(config, operations = operations, **factory_kwargs)
|
||||
self.intermediate = ASTIntermediate(config, operations=operations, **factory_kwargs)
|
||||
self.output = ASTOutput(config, operations=operations, **factory_kwargs)
|
||||
self.layernorm_before = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)
|
||||
self.layernorm_after = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
tok_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states),
|
||||
tok_mask,
|
||||
head_mask,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:]
|
||||
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
|
||||
layer_output = self.output(layer_output, hidden_states)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ASTEncoder(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ASTLayer(config, device, dtype, operations) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
tok_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
class ASTModel(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embeddings = ASTEmbeddings(config, device, dtype, operations)
|
||||
self.encoder = ASTEncoder(config, device, dtype, operations)
|
||||
|
||||
self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def get_input_embeddings(self) -> ASTPatchEmbeddings:
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_values: Optional[torch.Tensor] = None,
|
||||
cont_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_values)
|
||||
|
||||
if cont_mask is not None:
|
||||
indicator = torch.ones_like(input_values).to(input_values.dtype)
|
||||
indicator[~cont_mask] = torch.inf
|
||||
with torch.no_grad():
|
||||
indicator = self.embeddings(indicator)
|
||||
tok_mask = ~torch.isnan(indicator)
|
||||
tok_mask = tok_mask[:, :, 0]
|
||||
else:
|
||||
tok_mask = None
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
tok_mask=tok_mask,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
sequence_output = encoder_outputs
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
|
||||
pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
|
||||
|
||||
return (
|
||||
BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
),
|
||||
tok_mask,
|
||||
)
|
||||
|
||||
class ASTMLPHead(nn.Module):
|
||||
def __init__(self, config: ASTConfig, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
|
||||
self.dense = operations.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.layernorm(hidden_state)
|
||||
hidden_state = self.dense(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
class RandInitPositionalEncoding(nn.Module):
|
||||
def __init__(self, block_shape: list, n_embd: int, device = None, dtype = None,):
|
||||
super().__init__()
|
||||
self.block_shape = block_shape
|
||||
self.n_embd = n_embd
|
||||
self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, token_embeddings):
|
||||
return token_embeddings + self.pos_emb
|
||||
|
||||
|
||||
class GlobalTransformer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tok_pdrop=0.0,
|
||||
embd_pdrop=0.1,
|
||||
resid_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
n_layer=3,
|
||||
n_head=8,
|
||||
n_embd=768,
|
||||
pos_emb_block_shape=[
|
||||
198,
|
||||
],
|
||||
n_off_head_out=21,
|
||||
device = None, dtype = None, operations = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device":device, "dtype": dtype}
|
||||
self.config = Config(
|
||||
embd_pdrop=embd_pdrop,
|
||||
resid_pdrop=resid_pdrop,
|
||||
attn_pdrop=attn_pdrop,
|
||||
n_layer=n_layer,
|
||||
n_head=n_head,
|
||||
n_embd=n_embd,
|
||||
)
|
||||
# input norm
|
||||
self.vis_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
|
||||
self.aud_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
|
||||
# aux tokens
|
||||
self.OFF_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
|
||||
self.MOD_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
|
||||
# whole token dropout
|
||||
self.tok_pdrop = tok_pdrop
|
||||
self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
|
||||
self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
|
||||
# maybe add pos emb
|
||||
self.pos_emb_cfg = RandInitPositionalEncoding(
|
||||
block_shape=pos_emb_block_shape,
|
||||
n_embd=n_embd,
|
||||
)
|
||||
# the stem
|
||||
self.drop = torch.nn.Dropout(embd_pdrop)
|
||||
self.blocks = nn.Sequential(*[Block(self.config, operations=operations, **factory_kwargs) for _ in range(n_layer)])
|
||||
# pre-output norm
|
||||
self.ln_f = operations.LayerNorm(n_embd)
|
||||
# maybe add a head
|
||||
self.off_head = operations.Linear(in_features=n_embd, out_features=n_off_head_out)
|
||||
|
||||
def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
|
||||
B, Sv, D = v.shape
|
||||
B, Sa, D = a.shape
|
||||
|
||||
off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
|
||||
mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
|
||||
|
||||
v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
|
||||
|
||||
if self.tok_pdrop > 0:
|
||||
v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
|
||||
|
||||
x = torch.cat((off_tok, v, mod_tok, a), dim=1)
|
||||
if hasattr(self, "pos_emb_cfg"):
|
||||
x = self.pos_emb_cfg(x)
|
||||
|
||||
x = self.drop(x)
|
||||
x = self.blocks(x)
|
||||
x = self.ln_f(x)
|
||||
|
||||
if attempt_to_apply_heads and hasattr(self, "off_head"):
|
||||
x = self.off_head(x[:, 0, :])
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, config, device, dtype, operations):
|
||||
super().__init__()
|
||||
|
||||
self.key = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
|
||||
self.query = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
|
||||
self.value = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
|
||||
|
||||
self.attn_drop = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
||||
|
||||
self.proj = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
|
||||
self.n_head = config.n_head
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size()
|
||||
|
||||
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
||||
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
||||
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
|
||||
|
||||
y = optimized_attention(q, k, v, self.n_head, skip_reshape=True)
|
||||
|
||||
y = self.resid_drop(self.proj(y))
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, config, device, dtype, operations):
|
||||
super().__init__()
|
||||
factory_kwargs = {"device":device, "dtype":dtype}
|
||||
self.ln1 = operations.LayerNorm(config.n_embd, **factory_kwargs)
|
||||
self.ln2 = operations.LayerNorm(config.n_embd, **factory_kwargs)
|
||||
self.attn = SelfAttention(config, device, dtype, operations)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(config.n_embd, 4 * config.n_embd, **factory_kwargs),
|
||||
nn.GELU(),
|
||||
operations.Linear(4 * config.n_embd, config.n_embd, **factory_kwargs),
|
||||
nn.Dropout(config.resid_pdrop),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.ln1(x))
|
||||
x = x + self.mlp(self.ln2(x))
|
||||
return x
|
||||
|
||||
class Synchformer(nn.Module):
|
||||
|
||||
def __init__(self, device, dtype, operations):
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device":device, "dtype":dtype}
|
||||
|
||||
self.vfeat_extractor = MotionFormer(operations = operations, **factory_kwargs)
|
||||
self.afeat_extractor = AST(
|
||||
operations = operations,
|
||||
max_spec_t = 66,
|
||||
factorize_freq_time = True,
|
||||
**factory_kwargs
|
||||
)
|
||||
|
||||
self.vproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs)
|
||||
self.aproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs)
|
||||
self.transformer = GlobalTransformer(
|
||||
tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768, operations=operations, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, vis):
|
||||
vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
|
||||
vis = self.vfeat_extractor(vis)
|
||||
return vis
|
||||
|
||||
def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
|
||||
vis = self.vproj(vis)
|
||||
aud = self.aproj(aud)
|
||||
|
||||
B, S, tv, D = vis.shape
|
||||
B, S, ta, D = aud.shape
|
||||
vis = vis.view(B, S * tv, D)
|
||||
aud = aud.view(B, S * ta, D)
|
||||
|
||||
logits = self.transformer(vis, aud)
|
||||
|
||||
return logits
|
||||
|
||||
def extract_vfeats(self, vis):
|
||||
return self.vfeat_extractor(vis.permute(0, 1, 3, 2, 4, 5))
|
||||
|
||||
def extract_afeats(self, aud):
|
||||
B, S, _, Fa, Ta = aud.shape
|
||||
aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2)
|
||||
aud, _ = self.afeat_extractor(aud)
|
||||
return aud
|
||||
85
comfy/ldm/hunyuan_foley/vae.py
Normal file
@ -0,0 +1,85 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from einops import rearrange
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
|
||||
from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class DAC(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 64,
|
||||
encoder_rates: List[int] = [2, 4, 8, 8],
|
||||
latent_dim: int = None,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 4, 2],
|
||||
sample_rate: int = 44100,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
|
||||
|
||||
self.decoder = DACDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
operations = ops
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
|
||||
def decode(self, z: torch.Tensor):
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
class FoleyVae(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
self.dac = DAC()
|
||||
self.syncformer = Synchformer(None, None, operations = ops)
|
||||
self.syncformer_preprocess = v2.Compose(
|
||||
[
|
||||
v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.CenterCrop(224),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
]
|
||||
)
|
||||
def decode(self, x, vae_options = {}):
|
||||
return self.dac.decode(x)
|
||||
def encode(self, x):
|
||||
return self.syncformer(x)
|
||||
|
||||
def video_encoding(self, video, step: int):
|
||||
|
||||
if not isinstance(video, torch.Tensor):
|
||||
video = torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||
|
||||
video = self.syncformer_preprocess(video).unsqueeze(0)
|
||||
seg_len = 16
|
||||
t = video.size(1)
|
||||
nseg = max(0, (t - seg_len) // step + 1)
|
||||
clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
|
||||
data = torch.stack(clips, dim=1)
|
||||
data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
|
||||
|
||||
return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
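# The video is sliced into fixed 16-frame clips starting every `step` frames, so consecutive
# clips overlap by max(0, 16 - step) frames (step=8, the node default, gives 50% overlap);
# the returned lambda restores the (batch, segments * tokens, dim) layout after encoding.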
|
||||
@ -158,7 +158,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, freq_scaling = 1.0):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
@ -180,6 +180,10 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
|
||||
if freq_scaling != 1.0:
|
||||
freqs *= freq_scaling
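# freq_scaling uniformly rescales the rotary frequencies computed above: values below 1.0
# stretch the effective RoPE wavelengths (positions advance more slowly), values above 1.0
# compress them; the default of 1.0 leaves the original behaviour unchanged.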
|
||||
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
|
||||
@ -2,6 +2,7 @@ import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
@ -1032,4 +1033,162 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
# comfyui implementation of nn.MultiheadAttention
|
||||
class MultiheadAttentionComfyv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
bias=True,
|
||||
batch_first=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
# to avoid pytorch checkpoint registration
|
||||
object.__setattr__(self, "_q_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
|
||||
object.__setattr__(self, "_k_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
|
||||
object.__setattr__(self, "_v_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
|
||||
|
||||
self.out_proj = operations.Linear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
# overwriting state dict loading to convert in_proj_weight/bias -> self._q_proj/_k_proj/_v_proj
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
def pop_key(k):
|
||||
return state_dict.pop(k) if k in state_dict else None
|
||||
|
||||
in_proj_w_key = prefix + "in_proj_weight"
|
||||
in_proj_b_key = prefix + "in_proj_bias"
|
||||
|
||||
weight = pop_key(in_proj_w_key)
|
||||
if weight is not None:
|
||||
q_w, k_w, v_w = torch.chunk(weight, 3)
|
||||
self._q_proj.weight.data.copy_(q_w.to(self._q_proj.weight.device, dtype=self._q_proj.weight.dtype))
|
||||
self._k_proj.weight.data.copy_(k_w.to(self._k_proj.weight.device, dtype=self._k_proj.weight.dtype))
|
||||
self._v_proj.weight.data.copy_(v_w.to(self._v_proj.weight.device, dtype=self._v_proj.weight.dtype))
|
||||
|
||||
bias = pop_key(in_proj_b_key)
|
||||
if bias is not None:
|
||||
q_b, k_b, v_b = torch.chunk(bias, 3)
|
||||
if getattr(self._q_proj, "bias", None) is not None:
|
||||
self._q_proj.bias.data.copy_(q_b.to(self._q_proj.bias.device, dtype=self._q_proj.bias.dtype))
|
||||
self._k_proj.bias.data.copy_(k_b.to(self._k_proj.bias.device, dtype=self._k_proj.bias.dtype))
|
||||
self._v_proj.bias.data.copy_(v_b.to(self._v_proj.bias.device, dtype=self._v_proj.bias.dtype))
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
def forward(self, src, attn_mask = None, key_padding_mask = None):
|
||||
|
||||
q = self._q_proj(src)
|
||||
k = self._k_proj(src)
|
||||
v = self._v_proj(src)
|
||||
|
||||
output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
|
||||
return self.out_proj(output)
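# Note: key_padding_mask is accepted for signature compatibility with nn.MultiheadAttention
# but is not applied here; only attn_mask is forwarded to optimized_attention.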
|
||||
|
||||
# comfyui implementation of nn.TransformerEncoderLayer
|
||||
class TransformerEncoderComfyv(nn.Module):
|
||||
def __init__(self,
|
||||
d_model, nhead, dim_feedforward,
|
||||
norm_first = False,
|
||||
layer_norm_eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
activation = F.relu,
|
||||
device = None,
|
||||
dtype = None, operations = None, **kwargs):
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, **factory_kwargs)
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
|
||||
|
||||
self.norm_first = norm_first
|
||||
self.norm1 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||
self.norm2 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.self_attn = MultiheadAttentionComfyv(
|
||||
embed_dim = d_model,
|
||||
num_heads = nhead,
|
||||
bias = bias,
|
||||
operations = operations,
|
||||
**factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, src, src_key_padding_mask = None, src_mask = None):
|
||||
src_key_padding_mask = F._canonical_mask(
|
||||
mask=src_key_padding_mask,
|
||||
mask_name="src_key_padding_mask",
|
||||
other_type=F._none_or_dtype(src_mask),
|
||||
other_name="src_mask",
|
||||
target_type=src.dtype,
|
||||
)
|
||||
|
||||
src_mask = F._canonical_mask(
|
||||
mask=src_mask,
|
||||
mask_name="src_mask",
|
||||
other_type=None,
|
||||
other_name="",
|
||||
target_type=src.dtype,
|
||||
check_other=False,
|
||||
)
|
||||
|
||||
x = src
|
||||
if self.norm_first:
|
||||
x = x + self._sa_block(
|
||||
self.norm1(x), src_mask, src_key_padding_mask
|
||||
)
|
||||
x = x + self._ff_block(self.norm2(x))
|
||||
else:
|
||||
x = self.norm1(
|
||||
x
|
||||
+ self._sa_block(x, src_mask, src_key_padding_mask)
|
||||
)
|
||||
x = self.norm2(x + self._ff_block(x))
|
||||
|
||||
return x
|
||||
|
||||
def _sa_block(
|
||||
self,
|
||||
x: Tensor,
|
||||
attn_mask: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor],
|
||||
) -> Tensor:
|
||||
x = self.self_attn(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
key_padding_mask=key_padding_mask,
|
||||
)
|
||||
return x
|
||||
|
||||
# feed forward block
|
||||
def _ff_block(self, x: Tensor) -> Tensor:
|
||||
return self.linear2(self.activation(self.linear1(x)))
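# TransformerEncoderComfyv mirrors the pre-norm / post-norm structure of nn.TransformerEncoderLayer
# but omits the dropout layers and routes self-attention through comfy's optimized_attention.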
|
||||
|
||||
@ -46,6 +46,7 @@ import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.higgsv2.model
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.hunyuan_foley.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1350,6 +1351,10 @@ class ACEStep(BaseModel):
|
||||
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class HunyuanFoley(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.hunyuan_foley.model.HunyuanVideoFoley):
|
||||
super().__init__(model_config, model_type, device, unet_model)
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
|
||||
@ -384,6 +384,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}triple_blocks.17.audio_cross_q.weight'.format(key_prefix) in state_dict_keys: # Hunyuan Foley
|
||||
return {}
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
|
||||
@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.hunyuan_foley.vae
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
@ -468,6 +469,12 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
# Hunyuan Foley
|
||||
elif "syncformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
|
||||
self.latent_dim = 128
|
||||
self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
|
||||
# TODO
|
||||
self.memory_used_encode = lambda shape, dtype: shape[0] * model_management.dtype_size(dtype)
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
|
||||
@ -4,6 +4,7 @@ from . import utils
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sdxl_clip
|
||||
import comfy.clap_model
|
||||
import comfy.text_encoders.sd2_clip
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.sa_t5
|
||||
@ -1266,6 +1267,21 @@ class Omnigen2(supported_models_base.BASE):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
class HunyuanFoley(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_foley",
|
||||
}
|
||||
|
||||
latent_format = latent_formats.HunyuanFoley
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
vae_key_prefix = ["dac."]
|
||||
text_encoder_key_prefix = ["clap."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.HunyuanFoley(self, device=device)
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.clap_model.ClapLargeTokenizer, comfy.clap_model.ClapTextEncoderModel)
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
|
||||
@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
def get_duration(self, return_frames = False) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -100,14 +100,16 @@ class VideoFromFile(VideoInput):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
return float(container.duration / av.time_base) if not return_frames\
else (float(container.duration / av.time_base), float(container.duration))
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
return float(video_stream.frames / video_stream.average_rate) if not return_frames\
else (float(video_stream.frames / video_stream.average_rate), float(video_stream.frames))
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
@ -117,7 +119,8 @@ class VideoFromFile(VideoInput):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
return float(frame_count / video_stream.average_rate) if not return_frames\
else (float(frame_count / video_stream.average_rate), float(frame_count))
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
|
||||
53
comfy_extras/nodes_hunyuan_foley.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
class EmptyLatentHunyuanFoley:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"length": ("INT", {"default": 12, "min": 1, "max": 15, "tooltip": "The length of the audio. The same length as the video."}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent audios in the batch."}),
|
||||
},
|
||||
"optional": {"video": ("VIDEO")}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, length, batch_size, video = None):
|
||||
if video is not None:
|
||||
_, length = video.get_duration(return_frames = True)
|
||||
length /= 25
|
||||
shape = (batch_size, 128, int(50 * length))
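# Latent shape assumptions taken from this node as written: 128 channels (matching the
# latent_dim set for the Foley VAE) and roughly 50 latent steps per second, with the video
# length converted from frames at an assumed 25 fps above.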
|
||||
latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan_foley"}, )
|
||||
|
||||
class HunyuanFoleyConditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"video_encoding_siglip": ("CONDITIONING",),
|
||||
"video_encoding_synchformer": ("CONDITIONING",),
|
||||
"text_encoding": ("CONDITIONING",)
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, video_encoding_siglip, video_encoding_synchformer, text_encoding):
|
||||
embeds = torch.cat([video_encoding_siglip, video_encoding_synchformer, text_encoding], dim = 0)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HunyuanFoleyConditioning": HunyuanFoleyConditioning,
|
||||
"EmptyLatentHunyuanFoley": EmptyLatentHunyuanFoley,
|
||||
}
|
||||
@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import av
|
||||
import torch
|
||||
import folder_paths
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
@ -14,6 +16,116 @@ from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy.cli_args import args
|
||||
|
||||
class EncodeVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EncodeVideo",
|
||||
display_name="Encode Video",
|
||||
category="image/video",
|
||||
description="Encode a video using an image encoder.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to be encoded."),
|
||||
io.Int.Input(
|
||||
"processing_batch_size", default=-1, min=-1,
|
||||
tooltip=(
|
||||
"Number of frames/segments to process at a time during encoding.\n"
|
||||
"-1 means process all at once. Smaller values reduce GPU memory usage."
|
||||
),
|
||||
),
|
||||
io.Int.Input("step_size", default=8, min=1, max=32,
|
||||
tooltip=(
|
||||
"Stride (in frames) between the start of consecutive segments.\n"
|
||||
"Smaller step = more overlap and smoother temporal coverage "
|
||||
"but higher compute cost. Larger step = faster but may miss detail."
|
||||
),
|
||||
),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.ClipVision.Input("clip_vision", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="encoded_video"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
|
||||
b, t, c, h, w = video.shape
|
||||
batch_size = b * t
|
||||
|
||||
if vae is None and clip_vision is None:
|
||||
raise ValueError("Must either have vae or clip_vision.")
|
||||
vae = vae if vae is not None else clip_vision
|
||||
|
||||
if hasattr(vae.first_stage_model, "video_encoding"):
|
||||
data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
|
||||
batch_size = b * num_segments
|
||||
else:
|
||||
data = video.view(batch_size, c, h, w)
|
||||
output_fn = lambda x: x.view(b, t, -1)
|
||||
|
||||
if processing_batch_size != -1:
|
||||
batch_size = processing_batch_size
|
||||
|
||||
outputs = []
|
||||
total = data.shape[0]
|
||||
for i in range(0, total, batch_size):
|
||||
chunk = data[i : i + batch_size]
|
||||
out = vae.encode(chunk)
|
||||
outputs.append(out)
|
||||
|
||||
output = torch.cat(outputs)
|
||||
|
||||
return output_fn(output)
|
||||
|
||||
class ResampleVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ResampleVideo",
|
||||
display_name="Resample Video",
|
||||
category="image/video",
|
||||
inputs = [
|
||||
io.Video.Input("video"),
|
||||
io.Int.Input("target_fps")
|
||||
],
|
||||
outputs=[io.Image.Output(display_name="images")]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, video: av.container.InputContainer, target_fps: int):
container = video  # the schema declares this input as "video"; the body below expects an av container
|
||||
# doesn't support upsampling
|
||||
|
||||
stream = container.streams.video[0]
|
||||
frames = []
|
||||
|
||||
src_rate = stream.average_rate or stream.guessed_rate
|
||||
src_fps = float(src_rate) if src_rate else None
|
||||
|
||||
# yield original frames if asked for upsampling or src is unknown
|
||||
if src_fps is None or target_fps > src_fps:
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
|
||||
frames.append(arr)
|
||||
return torch.stack(frames)
|
||||
|
||||
stream.thread_type = "AUTO"
|
||||
|
||||
next_time = 0.0
|
||||
step = 1.0 / target_fps
|
||||
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
if frame.time is None:
|
||||
continue
|
||||
t = frame.time
|
||||
while t >= next_time:
|
||||
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
|
||||
frames.append(arr)
|
||||
next_time += step
|
||||
|
||||
return torch.stack(frames)
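# Resampling is done by timestamp-based frame selection rather than interpolation: each decoded
# frame is emitted once per target timestamp it covers, so frames may be repeated or skipped.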
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -212,6 +324,8 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
EncodeVideo,
|
||||
ResampleVideo
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||