Yousef Rafat 2025-09-27 13:17:20 +03:00
parent fee1e57ea9
commit 12824eac0d
15 changed files with 2727 additions and 5 deletions

comfy/clap_model.py Normal file

@@ -0,0 +1,356 @@
from typing import Optional, Callable
import torch
from torch import nn
from comfy.text_encoders.bert import (
BertIntermediate as ClapTextIntermediate,
BertOutput as ClapTextOutput,
)
from comfy.ldm.modules.attention import optimized_attention
import os
from comfy import sd1_clip
from transformers import AutoTokenizer
def apply_chunking_to_forward(
forward_fn: Callable[..., torch.Tensor],
chunk_size: int,
chunk_dim: int,
*input_tensors,
) -> torch.Tensor:
if chunk_size > 0:
# chunk into tuples and apply forward_fn to each element in the tuple
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)
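# Illustrative sketch (example only; the helper name and toy Linear layer are invented):
# for a position-wise module, chunking along the sequence dimension matches the
# unchunked forward pass.
def _example_apply_chunking():
    ff = torch.nn.Linear(8, 8)
    x = torch.randn(1, 4, 8)
    chunked = apply_chunking_to_forward(ff, 2, 1, x)  # chunk_size=2 along dim 1
    assert torch.allclose(chunked, ff(x), atol=1e-6)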
# from modeling_roberta
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
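# Illustrative sketch (example only): pad positions keep padding_idx and real tokens
# count up from padding_idx + 1, matching the RoBERTa convention.
def _example_position_ids():
    ids = torch.tensor([[0, 5, 7, 1, 1]])  # 1 is the padding index
    pos = create_position_ids_from_input_ids(ids, padding_idx=1)
    assert pos.tolist() == [[2, 3, 4, 1, 1]]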
class ClapProjectionLayer(nn.Module):
def __init__(self, hidden_size, projection_dim, device, dtype, operations):
super().__init__()
self.linear1 = operations.Linear(hidden_size, projection_dim, device=device, dtype=dtype)
self.activation = torch.nn.ReLU()
self.linear2 = operations.Linear(projection_dim, projection_dim, device=device, dtype=dtype)
def forward(self, hidden_states):
hidden_states = self.linear1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.linear2(hidden_states)
return hidden_states
# same as RobertaEmbeddings
class ClapTextEmbeddings(nn.Module):
def __init__(self, vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, device, dtype, operations):
super().__init__()
self.word_embeddings = operations.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, device=device, dtype=dtype)
self.position_embeddings = operations.Embedding(max_position_embeddings, hidden_size, device=device, dtype=dtype)
self.token_type_embeddings = operations.Embedding(type_vocab_size, hidden_size, device=device, dtype=dtype)
self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.register_buffer(
"position_ids", torch.arange(max_position_embeddings, device=device, dtype=dtype).expand((1, -1)), persistent=True
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=device), persistent=True
)
# End copy
self.padding_idx = pad_token_id
self.position_embeddings = operations.Embedding(
max_position_embeddings, hidden_size, padding_idx=self.padding_idx, device=device, dtype=dtype
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if position_ids is None:
if input_ids is not None:
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
# same as AlignTextSelfAttention
class ClapTextSelfAttention(nn.Module):
def __init__(self, num_attention_heads, hidden_size, device, dtype, operations):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
self.key = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
self.value = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype)
self.scaling = self.attention_head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
) -> tuple[torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states, value_states = [t.contiguous() for t in (query_states, key_states, value_states)]
attn_output = optimized_attention(query_states, key_states, value_states, self.num_attention_heads, mask = attention_mask, skip_output_reshape=True, skip_reshape=True)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output.reshape(*input_shape, -1).contiguous()
class ClapTextSelfOutput(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, device, dtype, operations):
super().__init__()
self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# same as AlignTextAttention
class ClapTextAttention(nn.Module):
def __init__(self, num_attention_heads, hidden_size, layer_norm_eps, device, dtype, operations):
super().__init__()
self.self = ClapTextSelfAttention(num_attention_heads, hidden_size, device=device, dtype=dtype, operations = operations)
self.output = ClapTextSelfOutput(hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs,
)
return self.output(self_outputs, hidden_states)
# same as AlignTextLayer
class ClapTextLayer(nn.Module):
def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device, dtype, operations):
super().__init__()
# maybe we could allow chunking dynamically
self.chunk_size_feed_forward = 0
self.seq_len_dim = 1
self.attention = ClapTextAttention(num_attention_heads, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
self.intermediate = ClapTextIntermediate(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations)
self.output = ClapTextOutput(intermediate_size, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_attention_outputs = self.attention(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self_attention_outputs
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
return layer_output
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
# same as AlignTextEncoder
class ClapTextEncoder(nn.Module):
def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device, dtype, operations):
super().__init__()
self.layer = nn.ModuleList([ClapTextLayer(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device=device, dtype=dtype, operations=operations)
for _ in range(num_hidden_layers)])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
**kwargs,
):
for _, layer_module in enumerate(self.layer):
hidden_states = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
**kwargs,
)
return hidden_states
class ClapTextPooler(nn.Module):
def __init__(self, hidden_size, device, dtype, operations):
super().__init__()
self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class ClapTextModel(nn.Module):
def __init__(self, num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, num_hidden_layers,
device, dtype, operations, add_pooling_layer=True):
super().__init__()
self.embeddings = ClapTextEmbeddings(vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps,
device=device, dtype=dtype, operations=operations,)
self.encoder = ClapTextEncoder(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations,)
self.pooler = ClapTextPooler(hidden_size, device=device, dtype=dtype, operations=operations,) if add_pooling_layer else None
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
if input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=attention_mask,
)
sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
return sequence_output, pooled_output
class ClapTextModelWithProjection(nn.Module):
def __init__(
self,
hidden_size: int = 768,
intermediate_size: int = 3072,
layer_norm_eps: float = 1e-12,
max_position_embeddings: int = 514,
num_attention_heads: int = 12,
num_hidden_layers: int = 12,
projection_dim: int = 512,
type_vocab_size: int = 1,
vocab_size: int = 50265,
pad_token_id: int = 1,
device=None,
dtype=None,
operations=None
):
super().__init__()
self.text_model = ClapTextModel(num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings,
type_vocab_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
self.text_projection = ClapProjectionLayer(hidden_size, projection_dim, device=device, dtype=dtype, operations=operations,)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
):
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)
return text_embeds, text_outputs[0]
class ClapTextEncoderModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 1}, layer_norm_hidden_state=False, model_class=ClapTextModelWithProjection, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ClapLargeTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
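# Illustrative usage sketch (example only): plain torch.nn stands in for the comfy
# `operations` namespace (an assumption for the example) and a tiny batch of token
# ids is pushed through the projection model.
def _example_clap_text_projection():
    model = ClapTextModelWithProjection(operations=torch.nn)
    tokens = torch.tensor([[0, 100, 200, 300, 2, 1, 1]])  # 1 = pad token id
    text_embeds, last_hidden = model(input_ids=tokens)
    assert text_embeds.shape == (1, 512) and last_hidden.shape == (1, 7, 768)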


@@ -551,3 +551,7 @@ class Hunyuan3Dv2mini(LatentFormat):
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class HunyuanFoley(LatentFormat):
latent_dimensions = 128
latent_channels = 1024


@@ -0,0 +1,921 @@
from typing import List, Tuple, Optional, Union, Dict
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from comfy.ldm.modules.attention import optimized_attention as attention
from comfy.ldm.aura.mmdit import TimestepEmbedder as TimestepEmbedderParent
from comfy.ldm.hydit.posemb_layers import get_1d_rotary_pos_embed
from typing import Union, Tuple
# kept as a separate implementation to get exactly matching results;
# the only difference is the upcast to float32
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
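# Illustrative sketch (example only): the first half of each embedding holds cosines
# and the second half sines of the scaled timestep.
def _example_timestep_embedding():
    t = torch.tensor([0.0, 500.0, 999.0])
    emb = timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)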
class TimestepEmbedder(TimestepEmbedderParent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
class SwiGLU(nn.Module):
def __init__(self, dim: int, hidden_dim: int, device, dtype, operations):
super().__init__()
self.w1 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
self.w2 = operations.Linear(hidden_dim, hidden_dim, bias=False, device=device, dtype=dtype)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
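# Illustrative sketch (example only, with torch.nn standing in for the comfy ops
# namespace): note this SwiGLU projects to hidden_dim rather than back to dim, which
# is how visual_proj later maps CLIP features into the model width.
def _example_swiglu():
    ff = SwiGLU(dim=16, hidden_dim=32, device=None, dtype=None, operations=torch.nn)
    assert ff(torch.randn(2, 5, 16)).shape == (2, 5, 32)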
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
ndim = x.ndim
if head_first:
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
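# Illustrative sketch (example only): with cos=1 and sin=0 the rotation is the
# identity, which makes the expected (B, L, H, D) layout easy to check.
def _example_rotary_identity():
    xq = torch.randn(1, 4, 2, 8)
    xk = torch.randn(1, 4, 2, 8)
    freqs_cis = (torch.ones(4, 8), torch.zeros(4, 8))  # per-position cos/sin tables
    q_out, k_out = apply_rotary_emb(xq, xk, freqs_cis, head_first=False)
    assert torch.allclose(q_out, xq) and torch.allclose(k_out, xk)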
class ConditionProjection(nn.Module):
def __init__(self, in_channels, hidden_size, dtype=None, device=None, operations = None):
factory_kwargs = {'dtype': dtype, 'device': device}
super().__init__()
self.linear_1 = operations.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
self.act_1 = nn.SiLU()
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
def forward(self, caption):
return self.linear_2(self.act_1(self.linear_1(caption)))
class PatchEmbed1D(nn.Module):
def __init__(
self,
patch_size=1,
in_chans=768,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
device=None,
operations = None
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.flatten = flatten
self.proj = operations.Conv1d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.transpose(1, 2)
x = self.norm(x)
return x
# avoid classifying as wrapper to work with operations.conv1d
class ChannelLastConv1d(nn.Module):
def __init__(self, in_channels, out_channels, bias=True, kernel_size = 3, padding = 0, device=None, dtype=None, operations=None):
super().__init__()
operations = operations or nn
underlying = operations.Conv1d(
in_channels, out_channels, kernel_size = kernel_size, padding = padding,
bias=bias, device=device, dtype=dtype
)
self.register_parameter("weight", underlying.weight)
if getattr(underlying, "bias", None) is not None:
self.register_parameter("bias", underlying.bias)
else:
self.register_parameter("bias", None)
object.__setattr__(self, "_underlying", underlying)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self._underlying(x.permute(0, 2, 1))
return x.permute(0, 2, 1)
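# Illustrative sketch (example only): inputs and outputs are channels-last (B, L, C);
# the wrapper permutes around the underlying Conv1d.
def _example_channel_last_conv():
    conv = ChannelLastConv1d(8, 16, kernel_size=3, padding=1)
    assert conv(torch.randn(2, 10, 8)).shape == (2, 10, 16)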
class ConvMLP(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
kernel_size: int = 3,
padding: int = 1,
device=None,
dtype=None,
operations = None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
def modulate(x, shift=None, scale=None):
if x.ndim == 3:
shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale)
elif scale is None:
return x + shift
else:
return x * (1 + scale) + shift
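# Illustrative sketch (example only): per-sample (B, C) shift/scale broadcast over the
# sequence dimension of a (B, L, C) activation.
def _example_modulate():
    x = torch.randn(2, 5, 16)
    out = modulate(x, shift=torch.zeros(2, 16), scale=torch.ones(2, 16))
    assert torch.allclose(out, 2 * x)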
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations = None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.act = nn.SiLU()
self.linear = operations.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.act(x))
class FinalLayer1D(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels, device=None, dtype=None, operations = None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.linear = operations.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
class MLP(nn.Module):
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
operations = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = (bias, bias)
drop_probs = (drop, drop)
linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear
self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
self.act = nn.GELU(approximate="tanh")
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x))))))
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij")
grid = torch.stack(grid, dim=0)
return grid
def get_nd_rotary_pos_embed(
rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
freq_scaling=freq_scaling,
)
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1)
sin = torch.cat([emb[1] for emb in embs], dim=1)
return cos, sin
else:
emb = torch.cat(embs, dim=1)
return emb
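# Illustrative sketch (example only; shapes assume the stock hydit helper returns real
# cos/sin tables with one row per position): a single 1-D axis of 16 positions with a
# 64-dim head yields cos/sin tables of shape (16, 64).
def _example_rope_tables():
    cos, sin = get_nd_rotary_pos_embed([64], [16], use_real=True)
    # cos.shape == sin.shape == (16, 64)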
def apply_gate(x, gate = None):
if gate is None:
return x
if gate.ndim == 2 and x.ndim == 3:
gate = gate.unsqueeze(1)
return x * gate
def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
B, N1, H, C = x1.shape
B, N2, H, C = x2.shape
assert x1.ndim == x2.ndim == 4
if N1 != N2:
x2 = x2.view(B, N2, -1).transpose(1, 2)
x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
x2 = x2.transpose(1, 2).view(B, N1, H, C)
x = torch.stack((x1, x2), dim=2)
x = x.reshape(B, N1 * 2, H, C)
return x
def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
B, N, H, C = x.shape
assert N % 2 == 0 and N // 2 == len1
x = x.reshape(B, -1, 2, H, C)
x1 = x[:, :, 0]
x2 = x[:, :, 1]
if x2.shape[1] != len2:
x2 = x2.view(B, len1, H * C).transpose(1, 2)
x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
x2 = x2.transpose(1, 2).view(B, len2, H, C)
return x1, x2
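# Illustrative sketch (example only): when both sequences have the same length,
# interleaving and decoupling round-trip exactly (no interpolation is needed).
def _example_interleave_roundtrip():
    a, b = torch.randn(1, 6, 2, 4), torch.randn(1, 6, 2, 4)
    x = interleave_two_sequences(a, b)  # (1, 12, 2, 4)
    a2, b2 = decouple_interleaved_two_sequences(x, 6, 6)
    assert torch.allclose(a2, a) and torch.allclose(b2, b)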
def apply_modulated_block(x, norm_layer, shift, scale, mlp_layer, gate):
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
return x + apply_gate(mlp_layer(x_mod), gate=gate)
def prepare_self_attn_qkv(x, norm_layer, qkv_layer, q_norm, k_norm, shift, scale, num_heads):
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
qkv = qkv_layer(x_mod)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=num_heads)
q = q_norm(q).to(v)
k = k_norm(k).to(v)
return q, k, v
class TwoStreamCABlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
qk_norm: bool = True,
qkv_bias: bool = False,
interleaved_audio_visual_rope: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations = None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
head_dim = hidden_size // num_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
self.audio_mod = ModulateDiT(hidden_size, factor=9, operations = operations, **factory_kwargs)
self.audio_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.audio_self_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
def make_qk_norm(name: str):
layer = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
setattr(self, name, layer)
for name in ["v_cond_attn_q_norm", "v_cond_attn_k_norm", "audio_cross_q_norm",
"v_cond_cross_q_norm", "text_cross_k_norm", "audio_self_q_norm", "audio_self_k_norm"]:
make_qk_norm(name)
self.audio_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.v_cond_mod = ModulateDiT(hidden_size, factor = 9, operations = operations, **factory_kwargs)
self.v_cond_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.v_cond_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
self.v_cond_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.max_text_len = 100
self.rope_dim_list = None
self.audio_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.v_cond_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.audio_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.v_cond_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.text_cross_kv = operations.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
self.audio_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.v_cond_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.audio_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.audio_mlp = MLP(
hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs
)
self.v_cond_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.v_cond_mlp = MLP(
hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs
)
def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
target_ndim = 1 # n-d RoPE
rope_sizes = [text_len]
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list=rope_dim_list,
start=rope_sizes,
theta=10000,
use_real=True,
theta_rescale_factor=1.0,
)
return text_freqs_cos, text_freqs_sin
def forward(
self,
audio: torch.Tensor,
cond: torch.Tensor,
v_cond: torch.Tensor,
attn_mask: torch.Tensor,
vec: torch.Tensor,
freqs_cis: tuple = None,
v_freqs_cis: tuple = None,
sync_vec: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
(audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
) = self.audio_mod(sync_vec if sync_vec is not None else vec).chunk(9, dim=-1)
(
v_cond_mod1_shift,
v_cond_mod1_scale,
v_cond_mod1_gate,
v_cond_mod2_shift,
v_cond_mod2_scale,
v_cond_mod2_gate,
v_cond_mod3_shift,
v_cond_mod3_scale,
v_cond_mod3_gate,
) = self.v_cond_mod(vec).chunk(9, dim=-1)
audio_q, audio_k, audio_v = prepare_self_attn_qkv(
audio, self.audio_norm1, self.audio_self_attn_qkv,
self.audio_self_q_norm, self.audio_self_k_norm,
audio_mod1_shift, audio_mod1_scale, self.num_heads
)
v_cond_q, v_cond_k, v_cond_v = prepare_self_attn_qkv(
v_cond, self.v_cond_norm1, self.v_cond_attn_qkv,
self.v_cond_attn_q_norm, self.v_cond_attn_k_norm,
v_cond_mod1_shift, v_cond_mod1_scale, self.num_heads
)
# Apply RoPE if needed for audio and visual
if freqs_cis is not None:
if not self.interleaved_audio_visual_rope:
audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
audio_q, audio_k = audio_qq, audio_kk
else:
ori_audio_len = audio_q.shape[1]
ori_v_con_len = v_cond_q.shape[1]
interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
)
audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
)
audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
)
audio_q, audio_k = audio_qq, audio_kk
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
q = torch.cat((v_cond_q, audio_q), dim=1)
k = torch.cat((v_cond_k, audio_k), dim=1)
v = torch.cat((v_cond_v, audio_v), dim=1)
# TODO: look into this further
if attention.__name__ == "attention_pytorch":
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
attn = attention(q, k, v, heads = self.num_heads, mask=attn_mask, skip_reshape=True)
v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
head_dim = self.hidden_size // self.num_heads
audio_q = self.prepare_modulated_query(audio, self.audio_norm2, self.audio_cross_q,
self.audio_cross_q_norm, audio_mod2_shift, audio_mod2_scale,
self.num_heads, self.rope_dim_list)
v_cond_q = self.prepare_modulated_query(v_cond, self.v_cond_norm2, self.v_cond_cross_q,
self.v_cond_cross_q_norm, v_cond_mod2_shift, v_cond_mod2_scale,
self.num_heads, self.rope_dim_list)
text_kv = self.text_cross_kv(cond)
text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
text_k = self.text_cross_k_norm(text_k).to(text_v)
text_len = text_k.shape[1]
text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
rope_dim_list=self.rope_dim_list)
text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
if attention.__name__ == "attention_pytorch":
v_cond_audio_q, text_k, text_v = [t.transpose(1, 2) for t in (v_cond_audio_q, text_k, text_v)]
cross_attn = attention(v_cond_audio_q, text_k, text_v, self.num_heads, skip_reshape = True)
v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
audio = apply_modulated_block(audio, self.audio_norm3, audio_mod3_shift, audio_mod3_scale, self.audio_mlp, audio_mod3_gate)
v_cond = apply_modulated_block(v_cond, self.v_cond_norm3, v_cond_mod3_shift, v_cond_mod3_scale, self.v_cond_mlp, v_cond_mod3_gate)
return audio, cond, v_cond
def prepare_modulated_query(self, x, norm_layer, q_layer, q_norm_layer, shift, scale, num_heads, rope_dim_list):
x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
q = q_layer(x_mod)
q = rearrange(q, "B L (H D) -> B L H D", H=num_heads)
q = q_norm_layer(q)
head_dim = q.shape[-1]
freqs_cos, freqs_sin = self.build_rope_for_text(q.shape[1], head_dim, rope_dim_list)
freqs_cis = (freqs_cos.to(q.device), freqs_sin.to(q.device))
q = apply_rotary_emb(q, q, freqs_cis, head_first=False)[0]
return q
class SingleStreamBlock(nn.Module):
def __init__(self, hidden_size: int,
num_heads: int,
mlp_ratio: float,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations = None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.modulation = ModulateDiT(
hidden_size=hidden_size,
factor=6,
operations = operations,
**factory_kwargs,
)
self.linear_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=True)
self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, operations = operations, **factory_kwargs)
self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, operations = operations, **factory_kwargs)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs)
self.q_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs)
self.k_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs)
self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)
def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
modulation = self.modulation(cond)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
qkv = self.linear_qkv(x_norm1)
q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
q, k, v = [t.squeeze(-1) for t in (q, k, v)]
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
q, k, v = [t.contiguous() for t in (q, k, v)]
out = attention(q, k, v, self.num_heads, skip_output_reshape = True, skip_reshape = True)
out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
x = x + apply_gate(self.linear1(out),gate=gate_msa)
x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
return x
class HunyuanVideoFoley(nn.Module):
def __init__(
self,
model_args,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations = None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.depth_triple_blocks = 18
self.depth_single_blocks = 36
self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", True)
self.condition_dim = model_args.get("condition_dim", 768)
self.patch_size = model_args.get("patch_size", 1)
self.visual_in_channels = model_args.get("clip_dim", 768)
self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
self.out_channels = self.audio_vae_latent_dim
self.unpatchify_channels = self.out_channels
self.num_heads = model_args.get("num_heads", 12)
self.hidden_size = model_args.get("hidden_size", 1536)
self.rope_dim_list = model_args.get("rope_dim_list", None)
self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
self.qkv_bias = model_args.get("qkv_bias", True)
self.qk_norm = model_args.get("qk_norm", True)
# sync condition things
self.sync_modulation = model_args.get("sync_modulation", False)
self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", True)
self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
self.clip_len = model_args.get("clip_length", 64)
self.sync_len = model_args.get("sync_length", 192)
self.patch_size = 1
self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, operations=operations, **factory_kwargs)
self.visual_proj = SwiGLU(dim = self.visual_in_channels, hidden_dim = self.hidden_size, device=device, dtype=dtype, operations=operations)
self.cond_in = ConditionProjection(
self.condition_dim, self.hidden_size, operations=operations, **factory_kwargs
)
self.time_in = TimestepEmbedder(self.hidden_size, operations = operations, **factory_kwargs)
# visual sync embedder if needed
if self.sync_in_ksz == 1:
sync_in_padding = 0
elif self.sync_in_ksz == 3:
sync_in_padding = 1
else:
raise ValueError
if self.sync_modulation or self.add_sync_feat_to_audio:
self.sync_in = nn.Sequential(
operations.Linear(self.sync_feat_dim, self.hidden_size, **factory_kwargs),
nn.SiLU(),
ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding, operations=operations, **factory_kwargs),
)
self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim), **factory_kwargs))
self.triple_blocks = nn.ModuleList(
[
TwoStreamCABlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
qk_norm=self.qk_norm,
qkv_bias=self.qkv_bias,
interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
operations=operations,
**factory_kwargs,
)
for _ in range(self.depth_triple_blocks)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
operations=operations,
**factory_kwargs,
)
for _ in range(self.depth_single_blocks)
]
)
self.final_layer = FinalLayer1D(
self.hidden_size, self.patch_size, self.out_channels, operations = operations,**factory_kwargs
)
self.unpatchify_channels = self.out_channels
self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
len = len if len is not None else self.clip_len
if bs is None:
return self.empty_clip_feat.expand(len, -1) # 15s
else:
return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
len = len if len is not None else self.sync_len
if bs is None:
return self.empty_sync_feat.expand(len, -1)
else:
return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
target_ndim = 1 # n-d RoPE
rope_sizes = [audio_emb_len]
head_dim = self.hidden_size // self.num_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list=rope_dim_list,
start=rope_sizes,
theta=10000,
use_real=True,
theta_rescale_factor=1.0,
)
target_ndim = 1
rope_sizes = [visual_cond_len]
head_dim = self.hidden_size // self.num_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list=rope_dim_list,
start=rope_sizes,
theta=10000,
use_real=True,
theta_rescale_factor=1.0,
freq_scaling=1.0 * audio_emb_len / visual_cond_len,
)
return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
def build_rope_for_interleaved_audio_visual(self, total_len):
target_ndim = 1 # n-d RoPE
rope_sizes = [total_len]
head_dim = self.hidden_size // self.num_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list=rope_dim_list,
start=rope_sizes,
theta=10000,
use_real=True,
theta_rescale_factor=1.0,
)
return freqs_cos, freqs_sin
def forward(
self,
x: torch.Tensor,
t: torch.Tensor,
full_cond: torch.Tensor,
transformer_options = {},
drop_visual: Optional[List[bool]] = None,
):
audio = x
bs, _, ol = x.shape
tl = ol // self.patch_size
condition, uncondition = torch.chunk(full_cond, 2)
uncond_1, uncond_2, uncond_3 = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond = torch.chunk(condition, 3)
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
if drop_visual is not None:
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
vec = self.time_in(t)
sync_vec = None
if self.add_sync_feat_to_audio:
sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb.to(sync_feat.device)
sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)
sync_feat = self.sync_in.to(sync_feat.device)(sync_feat)
add_sync_feat_to_audio = (
F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
)
cond = self.cond_in(cond)
cond_seq_len = cond.shape[1]
audio = self.audio_embedder(x)
audio_seq_len = audio.shape[1]
v_cond = self.visual_proj(clip_feat)
v_cond_seq_len = v_cond.shape[1]
attn_mask = None
freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
v_freqs_cos = v_freqs_sin = None
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None
if self.add_sync_feat_to_audio:
add_sync_layer = 0
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
def block_wrap(**kwargs):
return block(**kwargs)
for layer_num, block in enumerate(self.triple_blocks):
if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
audio = audio + add_sync_feat_to_audio
triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
if ("triple_block", layer_num) in blocks_replace:
audio, cond, v_cond = blocks_replace[("triple_block", layer_num)]({
"audio": triple_block_args[0],
"cond": triple_block_args[1],
"v_cond": triple_block_args[2],
"attn_mask": triple_block_args[3],
"vec": triple_block_args[4],
"freqs_cis": triple_block_args[5],
"v_freqs_cis": triple_block_args[6],
"sync_vec": triple_block_args[7]
}, {"original_block": block_wrap})
else:
audio, cond, v_cond = block(*triple_block_args)
x = audio
if sync_vec is not None:
vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
vec = torch.cat((vec, sync_vec), dim=1)
freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
if self.add_sync_feat_to_audio:
vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
if len(self.single_blocks) > 0:
for layer_num, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
(freqs_cos, freqs_sin),
]
if ("single_block", layer_num) in blocks_replace:
x = blocks_replace[("single_block", layer_num)]({
"x": single_block_args[0],
"vec": single_block_args[1],
"freqs_cis": single_block_args[2]
}, {"original_block": block_wrap})
else:
x = block(*single_block_args)
audio = x
if sync_vec is not None:
vec = sync_vec
audio = self.final_layer(audio, vec)
audio = self.unpatchify1d(audio, tl)
uncond, cond = torch.chunk(audio, 2)
return torch.cat([cond, uncond])
def unpatchify1d(self, x, l):
c = self.unpatchify_channels
p = self.patch_size
x = x.reshape(shape=(x.shape[0], l, p, c))
x = torch.einsum("ntpc->nctp", x)
audio = x.reshape(shape=(x.shape[0], c, l * p))
return audio


@@ -0,0 +1,991 @@
#!/usr/bin/env python3
from functools import partial
import torch
import torch.nn as nn
from einops import rearrange, repeat, einops
from comfy.ldm.hunyuan3dv2_1.hunyuandit import MLP as Mlp
from transformers.modeling_outputs import BaseModelOutputWithPooling
from comfy.ldm.modules.attention import optimized_attention, TransformerEncoderComfyv
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
from typing import Optional, Union, Tuple
class Config:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, device=None, dtype=None, operations=None):
super().__init__()
img_size = img_size if type(img_size) is tuple else (img_size, img_size)
patch_size = patch_size if type(patch_size) is tuple else (patch_size, patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
def forward(self, x):
x = self.proj(x).flatten(2).transpose(1, 2)
return x
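# Illustrative usage sketch (example only): plain torch.nn stands in for the comfy ops
# namespace; a 224x224 image becomes 14*14 = 196 patch tokens.
def _example_patch_embed():
    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768, operations=torch.nn)
    assert pe(torch.randn(1, 3, 224, 224)).shape == (1, 196, 768)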
class PatchEmbed3D(nn.Module):
def __init__(
self,
img_size=224,
in_chans=3,
patch_size=16,
z_block_size=2,
embed_dim=768,
flatten=True,
device=None, dtype=None, operations=None
):
super().__init__()
self.height = img_size // patch_size
self.width = img_size // patch_size
self.z_block_size = z_block_size
self.proj = operations.Conv3d(
in_chans,
embed_dim,
kernel_size=(z_block_size, patch_size, patch_size),
stride=(z_block_size, patch_size, patch_size),
device=device, dtype=dtype
)
self.flatten = flatten
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2)
return x
def qkv_attn(q, k, v, heads):
bh, seq_q, dim_head = q.shape
b = bh // heads
# (b*heads, seq, dim) -> (b, heads, seq, dim)
q2 = q.view(b, heads, seq_q, dim_head)
k2 = k.view(b, heads, k.shape[1], dim_head)
v2 = v.view(b, heads, v.shape[1], dim_head)
out = optimized_attention(q2, k2, v2, heads=heads, skip_reshape=True, skip_output_reshape=True)
# (b, heads, seq_q, dim_head) -> (b*heads, seq_q, dim_head): back to the (b h) n d layout of the inputs
out = out.reshape(b * heads, seq_q, dim_head)
return out
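# Illustrative sketch (example only): inputs and outputs use the flattened
# (batch*heads, seq, dim_head) layout produced by the rearranges in the callers below.
def _example_qkv_attn():
    heads, b, seq, dim_head = 2, 3, 5, 8
    q, k, v = (torch.randn(b * heads, seq, dim_head) for _ in range(3))
    assert qkv_attn(q, k, v, heads).shape == (b * heads, seq, dim_head)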
class DividedAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
h = self.num_heads
q, k, v = self.qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
cls_out = qkv_attn(cls_q, k, v, self.num_heads)
q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
r = q_.shape[0] // cls_k.shape[0]
cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v))
k_ = torch.cat((cls_k, k_), dim=1)
v_ = torch.cat((cls_v, v_), dim=1)
out = qkv_attn(q_, k_, v_, self.num_heads)
out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
out = torch.cat((cls_out, out), dim=1)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
x = self.proj(out)
return x
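# Illustrative usage sketch (example only): a CLS token plus 8 frames of 4 spatial
# tokens each, attended over the spatial axis; torch.nn stands in for the comfy ops
# namespace.
def _example_divided_attention():
    attn = DividedAttention(dim=64, num_heads=4, operations=torch.nn)
    x = torch.randn(2, 1 + 8 * 4, 64)
    out = attn(x, "b (f n) d", "(b f) n d", f=8)
    assert out.shape == x.shape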
class DividedSpaceTimeBlock(nn.Module):
def __init__(
self,
dim=768,
num_heads=12,
qkv_bias=False,
norm_layer=nn.LayerNorm,
device = None, dtype = None, operations = None
):
super().__init__()
factory_kwargs = {"device":device, "dtype": dtype}
self.einops_from_space = "b (f n) d"
self.einops_to_space = "(b f) n d"
self.einops_from_time = "b (f n) d"
self.einops_to_time = "(b n) f d"
self.norm1 = norm_layer(dim)
self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, operations = operations, **factory_kwargs)
self.timeattn = DividedAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, operations=operations, **factory_kwargs
)
self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(width = dim, operations = operations, device=device, dtype=dtype)
self.norm3 = norm_layer(dim)
def forward(self, x, seq_len=196, num_frames=8, tok_mask: torch.Tensor = None):
time_output = self.timeattn(
self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask
)
time_residual = x + time_output
space_output = self.attn(
self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask
)
space_residual = time_residual + self.drop_path(space_output)
x = space_residual
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class MotionFormer(nn.Module):
def __init__(self, device = None, dtype = None, operations = None):
super().__init__()
self.APPROX_ATTN_TYPE = "none"
self.APPROX_ATTN_DIM = 64
self.img_size = 224
self.patch_size = 16
self.in_chans = 3
self.num_classes = 174
self.embed_dim = 768
self.depth = 12
self.num_heads = 12
self.mlp_ratio = 4
self.qkv_bias = True
self.drop_rate = 0.0
self.drop_path_rate = 0.2
self.temporal_resolution = 8
self.use_mlp = True
self.num_features = self.embed_dim
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.attn_drop_rate = 0.0
self.factorize_space_time = True
# Patch Embedding
self.patch_embed = PatchEmbed(
img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim,
device=device, dtype=dtype, operations=operations
)
# 3D Patch Embedding
self.patch_embed_3d = PatchEmbed3D(
img_size=self.img_size,
patch_size=self.patch_size,
in_chans=self.in_chans,
embed_dim=self.embed_dim,
z_block_size = 2,
device=device, dtype=dtype, operations=operations
)
self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
# Number of patches
self.num_patches = self.patch_embed.num_patches * self.temporal_resolution
# CLS token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, device=device, dtype=dtype))
# Positional embedding
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim, device=device, dtype=dtype))
self.pos_drop = nn.Dropout(p=0.0)
self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim, device=device, dtype=dtype))
self.blocks = nn.ModuleList(
[
DividedSpaceTimeBlock(
dim=self.embed_dim,
num_heads=self.num_heads,
qkv_bias=self.qkv_bias,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations
)
for _ in range(self.depth)
]
)
self.norm = norm_layer(self.embed_dim)
self.pre_logits = nn.Identity()
transf_enc_layer_kwargs = dict(
d_model=self.embed_dim,
nhead=self.num_heads,
activation=nn.GELU(),
batch_first=True,
dim_feedforward=self.mlp_ratio * self.embed_dim,
dropout=self.drop_rate,
layer_norm_eps=1e-6,
norm_first=True,
)
self.spatial_attn_agg = SpatialTransformerEncoderLayer(device = device, dtype=dtype, operations=operations,**transf_enc_layer_kwargs)
self.temp_attn_agg = nn.Identity()
def forward_features(self, x):
B = x.shape[0]
# apply patching on input
x = self.patch_embed_3d(x)
tok_mask = None
# Append CLS token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
new_pos_embed = self.pos_embed
npatch = self.patch_embed.num_patches
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
total_pos_embed = tile_pos_embed + tile_temporal_embed
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
x = x + total_pos_embed
# Apply positional dropout
x = self.pos_drop(x)
# Encoding using transformer layers
for i, blk in enumerate(self.blocks):
x = blk(
x,
seq_len=npatch,
num_frames=self.temporal_resolution,
tok_mask=tok_mask,
)
return x, tok_mask
def forward(self, x):
B, S, C, T, H, W = x.shape
orig_shape = (B, S, C, T, H, W)
x = x.view(B * S, C, T, H, W) # flatten batch and segments
x = self.forward_segments(x, orig_shape=orig_shape)
x = x.view(B, S, *x.shape[1:])
return x
def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
x, x_mask = self.forward_features(x)
x = x[:, 1:, :]
x = self.norm(x)
x = self.pre_logits(x)
if self.factorize_space_time:
x = self.restore_spatio_temp_dims(x, orig_shape)
x = self.spatial_attn_agg(x, x_mask)
x = self.temp_attn_agg(x)
return x
def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
B, S, C, T, H, W = orig_shape
D = self.embed_dim
# num patches in each dimension
t = T // self.patch_embed_3d.z_block_size
h = self.patch_embed_3d.height
w = self.patch_embed_3d.width
feats = feats.permute(0, 2, 1) # (B*S, D, T)
feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
return feats
class BaseEncoderLayer(TransformerEncoderComfyv):
def __init__(
self,
add_pos_emb: bool = False,
pos_emb_drop: float = None,
pos_max_len: int = None,
device = None,
dtype = None, operations = None,
*args, **kwargs
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(operations = operations, *args, **kwargs, **factory_kwargs)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim, **factory_kwargs))
self.add_pos_emb = add_pos_emb
if add_pos_emb:
self.pos_max_len = 1 + pos_max_len
self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim, **factory_kwargs))
self.pos_drop = nn.Dropout(pos_emb_drop)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
batch_dim = x.shape[0]
cls_tokens = self.cls_token.expand(batch_dim, -1, -1)
x = torch.cat((cls_tokens, x), dim=-2)
if x_mask is not None:
cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device)
x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)
B, N = x_mask_w_cls.shape
x_mask_w_cls = (
x_mask_w_cls.reshape(B, 1, 1, N)
.expand(-1, self.self_attn.num_heads, N, -1)
.reshape(B * self.self_attn.num_heads, N, N)
)
assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool"
x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
else:
x_mask_w_cls = None
# add positional embedding
if self.add_pos_emb:
seq_len = x.shape[1]
assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})"
x = x + self.pos_emb[:, :seq_len, :]
x = self.pos_drop(x)
x = super().forward(src=x, src_mask=x_mask_w_cls)
x = x[:, 0, :]
return x
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
BS, D, t, h, w = x.shape
x = rearrange(x, "BS D t h w -> (BS t) (h w) D")
if x_mask is not None:
x_mask = rearrange(x_mask, "BS t h w -> (BS t) (h w)")
x = super().forward(x=x, x_mask=x_mask)
x = rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t)
return x
class AST(torch.nn.Module):
def __init__(
self,
max_spec_t: int = None,
factorize_freq_time: bool = None,
max_segments: int = None,
device = None, dtype = None, operations = None
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.extract_features = True
self.max_spec_t = max_spec_t
self.max_segments = max_segments
self.config = ASTConfig()
self.config.num_labels = 527
self.ast = ASTModel(self.config, device=device, dtype=dtype, operations=operations)
self.feat_type = "last_hidden_state"
self.factorize_freq_time = factorize_freq_time
transf_enc_layer_kwargs = dict(
d_model=self.config.hidden_size,
nhead=self.config.num_attention_heads,
dim_feedforward=self.config.intermediate_size,
activation=torch.nn.GELU(),
batch_first=True,
dropout=self.config.attention_probs_dropout_prob,
layer_norm_eps=1e-6,
norm_first=True,
)
if factorize_freq_time:
self.feat_type = "last_hidden_state"
self.freq_attn_agg = FrequencyTransformerEncoderLayer(operations = operations, **transf_enc_layer_kwargs, **factory_kwargs)
self.temp_attn_agg = torch.nn.Identity()
self.device = device
self.patch_position_emb()
def forward(
self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
) -> torch.Tensor:
B, S, T, F = x.shape
if for_loop:
assert cont_mask is None, "cont_mask is not supported with for_loop=True"
orig_shape_s = (B, 1, T, F)
x = torch.cat(
[self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
)
else:
orig_shape = (B, S, T, F)
x = x.view(B * S, T, F)
if cont_mask is not None:
cont_mask = cont_mask.reshape(B * S, T, F)
x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
x = x.view(B, S, *x.shape[1:])
global_x = None
return x, global_x
def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)
if self.extract_features:
x = self.get_features_by_type(x)
if self.factorize_freq_time:
x = self.restore_freq_temp_dims(x, orig_shape)
if cont_mask is not None:
x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)
x_mask = x_mask[:, 0, :, :]
else:
x_mask = None
x = self.freq_attn_agg(x, x_mask)
x = self.temp_attn_agg(x)
else:
x = x["pooler_output"]
x = self.classifier(x)
return x
def get_features_by_type(self, x) -> torch.Tensor:
return x["last_hidden_state"] # (B, 2+T, D)
def restore_freq_temp_dims(self, feats, orig_shape: tuple):
B, S, T, F = orig_shape
D = self.config.hidden_size
# num patches in each dimension
f, t = self.ast.embeddings.get_shape(self.config)
if self.feat_type == "last_hidden_state":
feats = feats[:, 2:, :] # removing CLS and distill tokens
feats = feats.permute(0, 2, 1) # (B*S, D, f*t)
feats = feats.view(B * S, D, f, t) # (B*S, D, f, t)
return feats
def patch_position_emb(self):
if self.max_spec_t is not None:
self.config.max_length = self.max_spec_t
f, t = self.ast.embeddings.get_shape(self.config)
shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens
self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened.to(self.device))
def to(self, device):
self.device = torch.device(device)
return super().to(device)
class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
BS, D, f, t = x.shape
x = x.permute(0, 3, 2, 1)
x = x.reshape(BS * t, f, D)
if x_mask is not None:
x_mask = x_mask.permute(0, 2, 1)
x_mask = x_mask.reshape(BS * t, f)
x = super().forward(x=x, x_mask=x_mask)
x = x.view(BS, t, D)
return x
class ASTEmbeddings(nn.Module):
def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None:
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype))
self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype))
self.patch_embeddings = ASTPatchEmbeddings(config, device, dtype, operations)
frequency_out_dimension, time_out_dimension = self.get_shape(config)
num_patches = frequency_out_dimension * time_out_dimension
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size, device=device, dtype=dtype))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def get_shape(self, config):
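# e.g. with the stock ASTConfig (num_mel_bins=128, patch_size=16, frequency_stride=10,
# time_stride=10, max_length=1024): f = (128 - 16) // 10 + 1 = 12 and t = (1024 - 16) // 10 + 1 = 101;
# the time dimension shrinks once max_length is overridden with max_spec_t in patch_position_emb()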
frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
return frequency_out_dimension, time_out_dimension
def forward(self, input_values: torch.Tensor) -> torch.Tensor:
batch_size = input_values.shape[0]
embeddings = self.patch_embeddings(input_values)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class ASTPatchEmbeddings(nn.Module):
def __init__(self, config, device = None, dtype = None, operations = None):
super().__init__()
patch_size = config.patch_size
frequency_stride = config.frequency_stride
time_stride = config.time_stride
self.projection = operations.Conv2d(
1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride), device = device, dtype = dtype
)
def forward(self, input_values: torch.Tensor) -> torch.Tensor:
input_values = input_values.unsqueeze(1)
input_values = input_values.transpose(2, 3)
embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
return embeddings
class ASTSelfAttention(nn.Module):
def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None:
super().__init__()
factory_kwargs = { "device": device, "dtype": dtype }
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
self.key = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
self.value = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
tok_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
if tok_mask is not None:
attn_mask = (tok_mask == 0)
attn_mask = attn_mask[:, None, None, :]
else:
attn_mask = None
context_layer = optimized_attention(query_layer, key_layer, value_layer, self.num_attention_heads, mask = attn_mask, skip_output_reshape=True, skip_reshape=True)
context_layer = context_layer.view(*query_layer.size())
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return (context_layer,)
class ASTSelfOutput(nn.Module):
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.dense = operations.Linear(config.hidden_size, config.hidden_size, device=device, dtype=dtype)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class ASTAttention(nn.Module):
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
super().__init__()
self.attention = ASTSelfAttention(config, device=device, dtype=dtype, operations=operations)
self.output = ASTSelfOutput(config, device=device, dtype=dtype, operations=operations)
def forward(
self,
hidden_states: torch.Tensor,
tok_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_outputs = self.attention(hidden_states, tok_mask, head_mask)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
super().__init__()
self.dense = operations.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
self.intermediate_act_fn = nn.GELU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ASTOutput(nn.Module):
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
super().__init__()
self.dense = operations.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ASTLayer(nn.Module):
def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None:
super().__init__()
factory_kwargs = {"device":device, "dtype":dtype}
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ASTAttention(config, operations = operations, **factory_kwargs)
self.intermediate = ASTIntermediate(config, operations=operations, **factory_kwargs)
self.output = ASTOutput(config, operations=operations, **factory_kwargs)
self.layernorm_before = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)
self.layernorm_after = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)
def forward(
self,
hidden_states: torch.Tensor,
tok_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states),
tok_mask,
head_mask,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
outputs = (layer_output,) + outputs
return outputs
class ASTEncoder(nn.Module):
def __init__(self, config: ASTConfig, device, dtype, operations) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList([ASTLayer(config, device, dtype, operations) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states: torch.Tensor,
tok_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
return hidden_states
class ASTModel(nn.Module):
def __init__(self, config: ASTConfig, device, dtype, operations):
super().__init__()
self.config = config
self.embeddings = ASTEmbeddings(config, device, dtype, operations)
self.encoder = ASTEncoder(config, device, dtype, operations)
self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
def get_input_embeddings(self) -> ASTPatchEmbeddings:
return self.embeddings.patch_embeddings
def forward(
self,
input_values: Optional[torch.Tensor] = None,
cont_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_values)
if cont_mask is not None:
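# masking trick: run an all-ones indicator with inf at masked inputs through the same
# patch embedding; tokens whose receptive field touched a masked value come out non-finite
# (typically NaN), so ~isnan marks the tokens that only saw valid input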
indicator = torch.ones_like(input_values).to(input_values.dtype)
indicator[~cont_mask] = torch.inf
with torch.no_grad():
indicator = self.embeddings(indicator)
tok_mask = ~torch.isnan(indicator)
tok_mask = tok_mask[:, :, 0]
else:
tok_mask = None
encoder_outputs = self.encoder(
embedding_output,
tok_mask=tok_mask,
head_mask=head_mask,
)
sequence_output = encoder_outputs
sequence_output = self.layernorm(sequence_output)
pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
return (
BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
),
tok_mask,
)
class ASTMLPHead(nn.Module):
def __init__(self, config: ASTConfig, device, dtype, operations):
super().__init__()
self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
self.dense = operations.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
def forward(self, hidden_state):
hidden_state = self.layernorm(hidden_state)
hidden_state = self.dense(hidden_state)
return hidden_state
class RandInitPositionalEncoding(nn.Module):
def __init__(self, block_shape: list, n_embd: int, device = None, dtype = None,):
super().__init__()
self.block_shape = block_shape
self.n_embd = n_embd
self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd, device=device, dtype=dtype))
def forward(self, token_embeddings):
return token_embeddings + self.pos_emb
class GlobalTransformer(torch.nn.Module):
def __init__(
self,
tok_pdrop=0.0,
embd_pdrop=0.1,
resid_pdrop=0.1,
attn_pdrop=0.1,
n_layer=3,
n_head=8,
n_embd=768,
pos_emb_block_shape=[
198,
],
n_off_head_out=21,
device = None, dtype = None, operations = None
) -> None:
super().__init__()
factory_kwargs = {"device":device, "dtype": dtype}
self.config = Config(
embd_pdrop=embd_pdrop,
resid_pdrop=resid_pdrop,
attn_pdrop=attn_pdrop,
n_layer=n_layer,
n_head=n_head,
n_embd=n_embd,
)
# input norm
self.vis_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
self.aud_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
# aux tokens
self.OFF_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
self.MOD_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
# whole token dropout
self.tok_pdrop = tok_pdrop
self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
# maybe add pos emb
self.pos_emb_cfg = RandInitPositionalEncoding(
block_shape=pos_emb_block_shape,
n_embd=n_embd,
)
# the stem
self.drop = torch.nn.Dropout(embd_pdrop)
self.blocks = nn.Sequential(*[Block(self.config, operations=operations, **factory_kwargs) for _ in range(n_layer)])
# pre-output norm
self.ln_f = operations.LayerNorm(n_embd)
# maybe add a head
self.off_head = operations.Linear(in_features=n_embd, out_features=n_off_head_out)
def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
B, Sv, D = v.shape
B, Sa, D = a.shape
off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
if self.tok_pdrop > 0:
v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
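# sequence layout: [OFF token, visual tokens, MOD token, audio tokens]; the OFF token's
# final state is what the synchronization-offset head reads below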
x = torch.cat((off_tok, v, mod_tok, a), dim=1)
if hasattr(self, "pos_emb_cfg"):
x = self.pos_emb_cfg(x)
x = self.drop(x)
x = self.blocks(x)
x = self.ln_f(x)
if attempt_to_apply_heads and hasattr(self, "off_head"):
x = self.off_head(x[:, 0, :])
return x
class SelfAttention(nn.Module):
def __init__(self, config, device, dtype, operations):
super().__init__()
self.key = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
self.query = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
self.value = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
self.proj = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype)
self.n_head = config.n_head
def forward(self, x):
B, T, C = x.size()
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
y = optimized_attention(q, k, v, self.n_head, skip_reshape=True)
y = self.resid_drop(self.proj(y))
return y
class Block(nn.Module):
def __init__(self, config, device, dtype, operations):
super().__init__()
factory_kwargs = {"device":device, "dtype":dtype}
self.ln1 = operations.LayerNorm(config.n_embd, **factory_kwargs)
self.ln2 = operations.LayerNorm(config.n_embd, **factory_kwargs)
self.attn = SelfAttention(config, device, dtype, operations)
self.mlp = nn.Sequential(
operations.Linear(config.n_embd, 4 * config.n_embd, **factory_kwargs),
nn.GELU(),
operations.Linear(4 * config.n_embd, config.n_embd, **factory_kwargs),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class Synchformer(nn.Module):
def __init__(self, device, dtype, operations):
super().__init__()
factory_kwargs = {"device":device, "dtype":dtype}
self.vfeat_extractor = MotionFormer(operations = operations, **factory_kwargs)
self.afeat_extractor = AST(
operations = operations,
max_spec_t = 66,
factorize_freq_time = True,
**factory_kwargs
)
self.vproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs)
self.aproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs)
self.transformer = GlobalTransformer(
tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768, operations=operations, **factory_kwargs
)
def forward(self, vis):
vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
vis = self.vfeat_extractor(vis)
return vis
def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
vis = self.vproj(vis)
aud = self.aproj(aud)
B, S, tv, D = vis.shape
B, S, ta, D = aud.shape
vis = vis.view(B, S * tv, D)
aud = aud.view(B, S * ta, D)
logits = self.transformer(vis, aud)
return logits
def extract_vfeats(self, vis):
return self.vfeat_extractor(vis.permute(0, 1, 3, 2, 4, 5))
def extract_afeats(self, aud):
B, S, _, Fa, Ta = aud.shape
aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2)
aud, _ = self.afeat_extractor(aud)
return aud
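# Usage sketch (comments only; shapes are indicative, not verified against a checkpoint,
# and ops stands for e.g. comfy.ops.disable_weight_init):
#   sync   = Synchformer(device=None, dtype=torch.float32, operations=ops)
#   vfeat  = sync.extract_vfeats(vis)        # vis: (B, S, Tv, C, H, W) RGB segments
#   afeat  = sync.extract_afeats(aud)        # aud: (B, S, 1, F, Ta) mel-spectrogram segments
#   logits = sync.compare_v_a(vfeat, afeat)  # projects both streams and returns offset logits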

View File

@ -0,0 +1,85 @@
import torch
import numpy as np
from typing import List
from einops import rearrange
from torchvision.transforms import v2
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder
import comfy.ops
ops = comfy.ops.disable_weight_init
class DAC(torch.nn.Module):
def __init__(
self,
encoder_dim: int = 64,
encoder_rates: List[int] = [2, 4, 8, 8],
latent_dim: int = None,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 4, 2],
sample_rate: int = 44100,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.sample_rate = sample_rate
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
self.decoder = DACDecoder(
latent_dim,
decoder_dim,
decoder_rates,
operations = ops
)
self.sample_rate = sample_rate
def decode(self, z: torch.Tensor):
return self.decoder(z)
def forward(self):
pass
class FoleyVae(torch.nn.Module):
def __init__(self):
super().__init__()
self.dac = DAC()
self.syncformer = Synchformer(None, None, operations = ops)
self.syncformer_preprocess = v2.Compose(
[
v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
v2.CenterCrop(224),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
def decode(self, x, vae_options = {}):
return self.dac.decode(x)
def encode(self, x):
return self.syncformer(x)
def video_encoding(self, video, step: int):
if not isinstance(video, torch.Tensor):
video = torch.from_numpy(video).permute(0, 3, 1, 2)
video = self.syncformer_preprocess(video).unsqueeze(0)
seg_len = 16
t = video.size(1)
nseg = max(0, (t - seg_len) // step + 1)
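# e.g. 64 preprocessed frames with step=8 give (64 - 16) // 8 + 1 = 7 overlapping
# 16-frame clips starting at frames 0, 8, ..., 48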
clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
data = torch.stack(clips, dim=1)
data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))

View File

@ -158,7 +158,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, freq_scaling = 1.0):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@ -180,6 +180,10 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
if freq_scaling != 1.0:
freqs *= freq_scaling
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:

View File

@ -2,6 +2,7 @@ import math
import sys
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
@ -1032,4 +1033,162 @@ class SpatialVideoTransformer(SpatialTransformer):
out = x + x_in
return out
# comfyui implementation of nn.MultiheadAttention
class MultiheadAttentionComfyv(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
bias=True,
batch_first=False,
device=None,
dtype=None,
operations = None
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
# assigned via object.__setattr__ so the q/k/v projections are not registered as submodules;
# their weights are filled from the checkpoint's fused in_proj tensors in _load_from_state_dict below
object.__setattr__(self, "_q_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
object.__setattr__(self, "_k_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
object.__setattr__(self, "_v_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs))
self.out_proj = operations.Linear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self.num_heads = num_heads
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
self.embed_dim = embed_dim
# overwriting state dict loading to convert in_proj_weight/bias -> self._q_proj/_k_proj/_v_proj
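# nn.MultiheadAttention stores the fused in_proj_weight as (3*embed_dim, embed_dim) with the
# query, key and value weights stacked along dim 0, so chunking by 3 recovers them in that order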
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
def pop_key(k):
return state_dict.pop(k) if k in state_dict else None
in_proj_w_key = prefix + "in_proj_weight"
in_proj_b_key = prefix + "in_proj_bias"
weight = pop_key(in_proj_w_key)
if weight is not None:
q_w, k_w, v_w = torch.chunk(weight, 3)
self._q_proj.weight.data.copy_(q_w.to(self._q_proj.weight.device, dtype=self._q_proj.weight.dtype))
self._k_proj.weight.data.copy_(k_w.to(self._k_proj.weight.device, dtype=self._k_proj.weight.dtype))
self._v_proj.weight.data.copy_(v_w.to(self._v_proj.weight.device, dtype=self._v_proj.weight.dtype))
bias = pop_key(in_proj_b_key)
if bias is not None:
q_b, k_b, v_b = torch.chunk(bias, 3)
if getattr(self._q_proj, "bias", None) is not None:
self._q_proj.bias.data.copy_(q_b.to(self._q_proj.bias.device, dtype=self._q_proj.bias.dtype))
self._k_proj.bias.data.copy_(k_b.to(self._k_proj.bias.device, dtype=self._k_proj.bias.dtype))
self._v_proj.bias.data.copy_(v_b.to(self._v_proj.bias.device, dtype=self._v_proj.bias.dtype))
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, src, attn_mask = None, key_padding_mask = None):
q = self._q_proj(src)
k = self._k_proj(src)
v = self._v_proj(src)
output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
return self.out_proj(output)
# comfyui implementation of nn.TransformerEncoderLayer
class TransformerEncoderComfyv(nn.Module):
def __init__(self,
d_model, nhead, dim_feedforward,
norm_first = False,
layer_norm_eps: float = 1e-5,
bias: bool = True,
activation = F.relu,
device = None,
dtype = None, operations = None, **kwargs):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.linear1 = operations.Linear(d_model, dim_feedforward, **factory_kwargs)
self.linear2 = operations.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
self.norm1 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.norm2 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.activation = activation
self.self_attn = MultiheadAttentionComfyv(
embed_dim = d_model,
num_heads = nhead,
bias = bias,
operations = operations,
**factory_kwargs
)
def forward(self, src, src_key_padding_mask = None, src_mask = None):
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(src_mask),
other_name="src_mask",
target_type=src.dtype,
)
src_mask = F._canonical_mask(
mask=src_mask,
mask_name="src_mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
x = src
if self.norm_first:
x = x + self._sa_block(
self.norm1(x), src_mask, src_key_padding_mask
)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(
x
+ self._sa_block(x, src_mask, src_key_padding_mask)
)
x = self.norm2(x + self._ff_block(x))
return x
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
) -> Tensor:
x = self.self_attn(
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)
return x
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
return self.linear2(self.activation(self.linear1(x)))

View File

@ -46,6 +46,7 @@ import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.higgsv2.model
import comfy.ldm.qwen_image.model
import comfy.ldm.hunyuan_foley.model
import comfy.model_management
import comfy.patcher_extension
@ -1350,6 +1351,10 @@ class ACEStep(BaseModel):
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
class HunyuanFoley(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.hunyuan_foley.model.HunyuanVideoFoley):
super().__init__(model_config, model_type, device, unet_model)
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@ -384,6 +384,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
return dit_config
if '{}triple_blocks.17.audio_cross_q.weight'.format(key_prefix) in state_dict_keys: # Hunyuan Foley
return {}
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape

View File

@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_foley.vae
import yaml
import math
import os
@ -468,6 +469,12 @@ class VAE:
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# Hunyuan Foley
elif "syncformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
self.latent_dim = 128
self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
# TODO
self.memory_used_encode = lambda shape, dtype: shape[0] * model_management.dtype_size(dtype)
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)

View File

@ -4,6 +4,7 @@ from . import utils
from . import sd1_clip
from . import sdxl_clip
import comfy.clap_model
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
@ -1266,6 +1267,21 @@ class Omnigen2(supported_models_base.BASE):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class HunyuanFoley(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_foley",
}
latent_format = latent_formats.HunyuanFoley
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["dac."]
text_encoder_key_prefix = ["clap."]
def get_model(self, state_dict, prefix="", device=None):
return model_base.HunyuanFoley(self, device=device)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.clap_model.ClapLargeTokenizer, comfy.clap_model.ClapTextEncoderModel)
class QwenImage(supported_models_base.BASE):
unet_config = {

View File

@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
def get_duration(self, return_frames = False) -> float:
"""
Returns the duration of the video in seconds.
@ -100,14 +100,16 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
return (float(container.duration / av.time_base), float(container.duration)) if return_frames else float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
return (float(video_stream.frames / video_stream.average_rate), float(video_stream.frames)) if return_frames else float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
@ -117,7 +119,8 @@ class VideoFromFile(VideoInput):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
return (float(frame_count / video_stream.average_rate), float(frame_count)) if return_frames else float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")

View File

@ -0,0 +1,53 @@
import torch
import comfy.model_management
class EmptyLatentHunyuanFoley:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"length": ("INT", {"default": 12, "min": 1, "max": 15, "tooltip": "The length of the audio. The same length as the video."}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent audios in the batch."}),
},
"optional": {"video": ("VIDEO")}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, length, batch_size, video = None):
if video is not None:
_, length = video.get_duration(return_frames = True)
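# convert the frame count to seconds, assuming a 25 fps source video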
length /= 25
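# the latent has 128 channels and 50 time steps per second of audio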
shape = (batch_size, 128, int(50 * length))
latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan_foley"}, )
class HunyuanFoleyConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"video_encoding_siglip": ("CONDITIONING",),
"video_encoding_synchformer": ("CONDITIONING",),
"text_encoding": ("CONDITIONING",)
},
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, video_encoding_siglip, video_encoding_synchformer, text_encoding):
embeds = torch.cat([video_encoding_siglip, video_encoding_synchformer, text_encoding], dim = 0)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
NODE_CLASS_MAPPINGS = {
"HunyuanFoleyConditioning": HunyuanFoleyConditioning,
"EmptyLatentHunyuanFoley": EmptyLatentHunyuanFoley,
}

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import os
import io
import av
import torch
import folder_paths
import json
import numpy as np
from typing import Optional
from typing_extensions import override
from fractions import Fraction
@ -14,6 +16,116 @@ from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
class EncodeVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EncodeVideo",
display_name="Encode Video",
category="image/video",
description="Encode a video using an image encoder.",
inputs=[
io.Video.Input("video", tooltip="The video to be encoded."),
io.Int.Input(
"processing_batch_size", default=-1, min=-1,
tooltip=(
"Number of frames/segments to process at a time during encoding.\n"
"-1 means process all at once. Smaller values reduce GPU memory usage."
),
),
io.Int.Input("step_size", default=8, min=1, max=32,
tooltip=(
"Stride (in frames) between the start of consecutive segments.\n"
"Smaller step = more overlap and smoother temporal coverage "
"but higher compute cost. Larger step = faster but may miss detail."
),
),
io.Vae.Input("vae", optional=True),
io.ClipVision.Input("clip_vision", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="encoded_video"),
],
)
@classmethod
def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
b, t, c, h, w = video.shape
batch_size = b * t
if vae is None and clip_vision is None:
raise ValueError("Must either have vae or clip_vision.")
vae = vae if vae is not None else clip_vision
if hasattr(getattr(vae, "first_stage_model", None), "video_encoding"):
data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
batch_size = b * num_segments
else:
data = video.view(batch_size, c, h, w)
output_fn = lambda x: x.view(b, t, -1)
if processing_batch_size != -1:
batch_size = processing_batch_size
outputs = []
total = data.shape[0]
for i in range(0, total, batch_size):
chunk = data[i : i + batch_size]
out = vae.encode(chunk)
outputs.append(out)
output = torch.cat(outputs)
return output_fn(output)
class ResampleVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ResampleVideo",
display_name="Resample Video",
category="image/video",
inputs = [
io.Video.Input("video"),
io.Int.Input("target_fps")
],
outputs=[io.Image.Output(display_name="images")]
)
@classmethod
def execute(cls, container: av.container.InputContainer, target_fps: int):
# doesn't support upsampling
stream = container.streams.video[0]
frames = []
src_rate = stream.average_rate or stream.guessed_rate
src_fps = float(src_rate) if src_rate else None
# return the original frames unchanged if upsampling was requested or the source rate is unknown
if src_fps is None or target_fps > src_fps:
for packet in container.demux(stream):
for frame in packet.decode():
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
return torch.stack(frames)
stream.thread_type = "AUTO"
next_time = 0.0
step = 1.0 / target_fps
for packet in container.demux(stream):
for frame in packet.decode():
if frame.time is None:
continue
t = frame.time
while t >= next_time:
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
next_time += step
return torch.stack(frames)
class SaveWEBM(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -212,6 +324,8 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
EncodeVideo,
ResampleVideo
]
async def comfy_entrypoint() -> VideoExtension:

View File

@ -2378,6 +2378,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_hunyuan_foley.py"
]
import_failed = []