diff --git a/comfy/clap_model.py b/comfy/clap_model.py new file mode 100644 index 000000000..dd3b014ae --- /dev/null +++ b/comfy/clap_model.py @@ -0,0 +1,356 @@ +from typing import Optional, Callable + +import torch +from torch import nn + +from comfy.text_encoders.bert import ( + BertIntermediate as ClapTextIntermediate, + BertOutput as ClapTextOutput, +) +from comfy.ldm.modules.attention import optimized_attention + +import os +from comfy import sd1_clip +from transformers import AutoTokenizer + +def apply_chunking_to_forward( + forward_fn: Callable[..., torch.Tensor], + chunk_size: int, + chunk_dim: int, + *input_tensors, +) -> torch.Tensor: + + if chunk_size > 0: + # chunk into tuples and apply forward_fn to each element in the tuple + num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size + input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) + + output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) + return torch.cat(output_chunks, dim=chunk_dim) + + return forward_fn(*input_tensors) + +# from modeling_roberta +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + +class ClapProjectionLayer(nn.Module): + def __init__(self, hidden_size, projection_dim, device, dtype, operations): + super().__init__() + self.linear1 = operations.Linear(hidden_size, projection_dim, device=device, dtype=dtype) + self.activation = torch.nn.ReLU() + self.linear2 = operations.Linear(projection_dim, projection_dim, device=device, dtype=dtype) + + def forward(self, hidden_states): + hidden_states = self.linear1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +# same as RobertaEmbeddings +class ClapTextEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, device, dtype, operations): + super().__init__() + self.word_embeddings = operations.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, device=device, dtype=dtype) + self.position_embeddings = operations.Embedding(max_position_embeddings, hidden_size, device=device, dtype=dtype) + self.token_type_embeddings = operations.Embedding(type_vocab_size, hidden_size, device=device, dtype=dtype) + + self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + self.register_buffer( + "position_ids", torch.arange(max_position_embeddings, device=device, dtype=dtype).expand((1, -1)), persistent=True + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=device), persistent=True + ) + + # End copy + self.padding_idx = pad_token_id + self.position_embeddings = operations.Embedding( + max_position_embeddings, hidden_size, padding_idx=self.padding_idx, device=device, dtype=dtype + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if 
input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + +# same as AlignTextSelfAttention +class ClapTextSelfAttention(nn.Module): + def __init__(self, num_attention_heads, hidden_size, device, dtype, operations): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype) + self.key = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype) + self.value = operations.Linear(hidden_size, self.all_head_size, device=device, dtype=dtype) + + self.scaling = self.attention_head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states, key_states, value_states = [t.contiguous() for t in (query_states, key_states, value_states)] + attn_output = optimized_attention(query_states, key_states, value_states, self.num_attention_heads, mask = attention_mask, skip_output_reshape=True, skip_reshape=True) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output.reshape(*input_shape, -1).contiguous() + +class ClapTextSelfOutput(nn.Module): + def __init__(self, hidden_size, layer_norm_eps, device, dtype, operations): + super().__init__() + self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype) + self.LayerNorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# same as AlignTextAttention +class ClapTextAttention(nn.Module): + def __init__(self, num_attention_heads, hidden_size, layer_norm_eps, device, dtype, operations): + 
super().__init__() + self.self = ClapTextSelfAttention(num_attention_heads, hidden_size, device=device, dtype=dtype, operations = operations) + self.output = ClapTextSelfOutput(hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + return self.output(self_outputs, hidden_states) + +# same as AlignTextLayer +class ClapTextLayer(nn.Module): + def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device, dtype, operations): + super().__init__() + # maybe we could allow chunking dynamically + self.chunk_size_feed_forward = 0 + self.seq_len_dim = 1 + self.attention = ClapTextAttention(num_attention_heads, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations) + self.intermediate = ClapTextIntermediate(hidden_size, intermediate_size, device=device, dtype=dtype, operations=operations) + self.output = ClapTextOutput(intermediate_size, hidden_size, layer_norm_eps, device=device, dtype=dtype, operations=operations) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + return layer_output + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# same as AlignTextEncoder +class ClapTextEncoder(nn.Module): + def __init__(self, num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device, dtype, operations): + super().__init__() + self.layer = nn.ModuleList([ClapTextLayer(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, device=device, dtype=dtype, operations=operations) + for _ in range(num_hidden_layers)]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ): + + for _, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + **kwargs, + ) + return hidden_states + +class ClapTextPooler(nn.Module): + def __init__(self, hidden_size, device, dtype, operations): + super().__init__() + self.dense = operations.Linear(hidden_size, hidden_size, device=device, dtype=dtype) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class ClapTextModel(nn.Module): + + def __init__(self, num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, num_hidden_layers, + device, dtype, operations, 
add_pooling_layer=True): + super().__init__() + + self.embeddings = ClapTextEmbeddings(vocab_size, hidden_size, pad_token_id, max_position_embeddings, type_vocab_size, layer_norm_eps, + device=device, dtype=dtype, operations=operations,) + self.encoder = ClapTextEncoder(num_attention_heads, hidden_size, intermediate_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations,) + self.pooler = ClapTextPooler(hidden_size, device=device, dtype=dtype, operations=operations,) if add_pooling_layer else None + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + + if input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + ) + sequence_output = encoder_outputs + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return sequence_output, pooled_output + +class ClapTextModelWithProjection(nn.Module): + def __init__( + self, + hidden_size: int = 768, + intermediate_size: int = 3072, + layer_norm_eps: float = 1e-12, + max_position_embeddings: int = 514, + num_attention_heads: int = 12, + num_hidden_layers: int = 12, + projection_dim: int = 512, + type_vocab_size: int = 1, + vocab_size: int = 50265, + pad_token_id: int = 1, + device=None, + dtype=None, + operations=None + ): + super().__init__() + self.text_model = ClapTextModel(num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings, + type_vocab_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations) + self.text_projection = ClapProjectionLayer(hidden_size, projection_dim, device=device, dtype=dtype, operations=operations,) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + pooled_output = text_outputs[1] + text_embeds = self.text_projection(pooled_output) + + return text_embeds, text_outputs[0] + +class ClapTextEncoderModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 1}, layer_norm_hidden_state=False, model_class=ClapTextModelWithProjection, enable_attention_masks=attention_mask, 
return_attention_masks=attention_mask, model_options=model_options) + +class ClapLargeTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 0d84994b0..865bbc7df 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -551,3 +551,7 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class HunyuanFoley(LatentFormat): + latent_dimensions = 128 + latent_channels = 1024 \ No newline at end of file diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py new file mode 100644 index 000000000..234415d13 --- /dev/null +++ b/comfy/ldm/hunyuan_foley/model.py @@ -0,0 +1,921 @@ +from typing import List, Tuple, Optional, Union, Dict +from functools import partial + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange + +from comfy.ldm.modules.attention import optimized_attention as attention +from comfy.ldm.aura.mmdit import TimestepEmbedder as TimestepEmbedderParent +from comfy.ldm.hydit.posemb_layers import get_1d_rotary_pos_embed + +from typing import Union, Tuple + +# to get exact matching results +# only difference is the upscale to float32 +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + +class TimestepEmbedder(TimestepEmbedderParent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + +class SwiGLU(nn.Module): + def __init__(self, dim: int, hidden_dim: int, device, dtype, operations): + super().__init__() + self.w1 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.w2 = operations.Linear(hidden_dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, 
Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + ndim = x.ndim + if head_first: + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + +class ConditionProjection(nn.Module): + def __init__(self, in_channels, hidden_size, dtype=None, device=None, operations = None): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.linear_1 = operations.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) + self.act_1 = nn.SiLU() + self.linear_2 = operations.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) + + def forward(self, caption): + return self.linear_2(self.act_1(self.linear_1(caption))) + +class PatchEmbed1D(nn.Module): + def __init__( + self, + patch_size=1, + in_chans=768, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + device=None, + operations = None + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.flatten = flatten + + self.proj = operations.Conv1d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs + ) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.transpose(1, 2) + x = self.norm(x) + return x + +# avoid classifying as wrapper to work with operations.conv1d +class ChannelLastConv1d(nn.Module): + def __init__(self, in_channels, out_channels, bias=True, kernel_size = 3, padding = 0, device=None, dtype=None, operations=None): + super().__init__() + + operations = operations or nn + underlying = operations.Conv1d( + in_channels, out_channels, kernel_size = kernel_size, padding = padding, + bias=bias, device=device, dtype=dtype + ) + + self.register_parameter("weight", underlying.weight) + if getattr(underlying, "bias", None) is not None: + self.register_parameter("bias", underlying.bias) + else: + self.register_parameter("bias", None) + + object.__setattr__(self, "_underlying", underlying) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._underlying(x.permute(0, 2, 1)) + return x.permute(0, 2, 1) + +class ConvMLP(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + device=None, + dtype=None, + operations = None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs) + self.w2 = 
ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs) + self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, operations = operations, **factory_kwargs) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + +def modulate(x, shift=None, scale=None): + if x.ndim == 3: + shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None + scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale) + elif scale is None: + return x + shift + else: + return x * (1 + scale) + shift + +class ModulateDiT(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations = None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = nn.SiLU() + self.linear = operations.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + +class FinalLayer1D(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels, device=None, dtype=None, operations = None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.linear = operations.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + +class MLP(nn.Module): + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + operations = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = (bias, bias) + drop_probs = (drop, drop) + linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear + + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) + self.act = nn.GELU(approximate="tanh") + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x)))))) + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + +def get_meshgrid_nd(start, *args, dim=2): + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, 
args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + grid = torch.stack(grid, dim=0) + + return grid + +def get_nd_rotary_pos_embed( + rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 +): + + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) + + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + freq_scaling=freq_scaling, + ) + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) + sin = torch.cat([emb[1] for emb in embs], dim=1) + return cos, sin + else: + emb = torch.cat(embs, dim=1) + return emb + +def apply_gate(x, gate = None): + if gate is None: + return x + if gate.ndim == 2 and x.ndim == 3: + gate = gate.unsqueeze(1) + return x * gate + +def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor): + B, N1, H, C = x1.shape + B, N2, H, C = x2.shape + assert x1.ndim == x2.ndim == 4 + + if N1 != N2: + x2 = x2.view(B, N2, -1).transpose(1, 2) + x2 = F.interpolate(x2, size=(N1), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, N1, H, C) + x = torch.stack((x1, x2), dim=2) + x = x.reshape(B, N1 * 2, H, C) + return x + +def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int): + B, N, H, C = x.shape + assert N % 2 == 0 and N // 2 == len1 + + x = x.reshape(B, -1, 2, H, C) + x1 = x[:, :, 0] + x2 = x[:, :, 1] + if x2.shape[1] != len2: + x2 = x2.view(B, len1, H * C).transpose(1, 2) + x2 = F.interpolate(x2, size=(len2), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, len2, H, C) + return x1, x2 + +def apply_modulated_block(x, norm_layer, shift, scale, mlp_layer, gate): + x_mod = modulate(norm_layer(x), shift=shift, scale=scale) + return x + apply_gate(mlp_layer(x_mod), gate=gate) + +def prepare_self_attn_qkv(x, norm_layer, qkv_layer, q_norm, k_norm, shift, scale, num_heads): + x_mod = modulate(norm_layer(x), shift=shift, scale=scale) + qkv = qkv_layer(x_mod) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=num_heads) + + q = q_norm(q).to(v) + k = k_norm(k).to(v) + return q, k, v + +class TwoStreamCABlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + qk_norm: bool = True, + qkv_bias: bool = False, + interleaved_audio_visual_rope: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations = None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.num_heads = num_heads + self.hidden_size = hidden_size + head_dim = hidden_size // num_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.interleaved_audio_visual_rope = interleaved_audio_visual_rope + + self.audio_mod = ModulateDiT(hidden_size, factor=9, operations = operations, **factory_kwargs) + self.audio_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_self_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + + def 
make_qk_norm(name: str): + layer = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + setattr(self, name, layer) + + for name in ["v_cond_attn_q_norm", "v_cond_attn_k_norm", "audio_cross_q_norm", + "v_cond_cross_q_norm", "text_cross_k_norm", "audio_self_q_norm", "audio_self_k_norm"]: + make_qk_norm(name) + + self.audio_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.v_cond_mod = ModulateDiT(hidden_size, factor = 9, operations = operations, **factory_kwargs) + self.v_cond_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_attn_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + + self.v_cond_self_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.max_text_len = 100 + self.rope_dim_list = None + + self.audio_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.audio_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.text_cross_kv = operations.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs) + + self.audio_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.audio_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_mlp = MLP( + hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs + ) + + self.v_cond_norm3 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_mlp = MLP( + hidden_size, mlp_hidden_dim, bias=True, operations = operations, **factory_kwargs + ) + + def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None): + target_ndim = 1 # n-d RoPE + rope_sizes = [text_len] + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + + text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + return text_freqs_cos, text_freqs_sin + + def forward( + self, + audio: torch.Tensor, + cond: torch.Tensor, + v_cond: torch.Tensor, + attn_mask: torch.Tensor, + vec: torch.Tensor, + freqs_cis: tuple = None, + v_freqs_cis: tuple = None, + sync_vec: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, + audio_mod2_shift, audio_mod2_scale, audio_mod2_gate, + audio_mod3_shift, audio_mod3_scale, audio_mod3_gate, + ) = self.audio_mod(sync_vec if sync_vec is not None else vec).chunk(9, dim=-1) + + ( + v_cond_mod1_shift, + v_cond_mod1_scale, + v_cond_mod1_gate, + v_cond_mod2_shift, + v_cond_mod2_scale, + v_cond_mod2_gate, + v_cond_mod3_shift, + v_cond_mod3_scale, + v_cond_mod3_gate, + ) = self.v_cond_mod(vec).chunk(9, dim=-1) + + audio_q, audio_k, audio_v = prepare_self_attn_qkv( + audio, self.audio_norm1, 
self.audio_self_attn_qkv, + self.audio_self_q_norm, self.audio_self_k_norm, + audio_mod1_shift, audio_mod1_scale, self.num_heads + ) + + v_cond_q, v_cond_k, v_cond_v = prepare_self_attn_qkv( + v_cond, self.v_cond_norm1, self.v_cond_attn_qkv, + self.v_cond_attn_q_norm, self.v_cond_attn_k_norm, + v_cond_mod1_shift, v_cond_mod1_scale, self.num_heads + ) + + # Apply RoPE if needed for audio and visual + if freqs_cis is not None: + if not self.interleaved_audio_visual_rope: + audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False) + audio_q, audio_k = audio_qq, audio_kk + else: + ori_audio_len = audio_q.shape[1] + ori_v_con_len = v_cond_q.shape[1] + interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q) + interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k) + interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb( + interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False + ) + audio_qq, v_cond_qq = decouple_interleaved_two_sequences( + interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len + ) + audio_kk, v_cond_kk = decouple_interleaved_two_sequences( + interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len + ) + audio_q, audio_k = audio_qq, audio_kk + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + if v_freqs_cis is not None and not self.interleaved_audio_visual_rope: + v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False) + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + q = torch.cat((v_cond_q, audio_q), dim=1) + k = torch.cat((v_cond_k, audio_k), dim=1) + v = torch.cat((v_cond_v, audio_v), dim=1) + + # TODO: look further into here + if attention.__name__ == "attention_pytorch": + q, k, v = [t.transpose(1, 2) for t in (q, k, v)] + + attn = attention(q, k, v, heads = self.num_heads, mask=attn_mask, skip_reshape=True) + v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate) + v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate) + head_dim = self.hidden_size // self.num_heads + + audio_q = self.prepare_modulated_query(audio, self.audio_norm2, self.audio_cross_q, + self.audio_cross_q_norm, audio_mod2_shift, audio_mod2_scale, + self.num_heads, self.rope_dim_list) + + v_cond_q = self.prepare_modulated_query(v_cond, self.v_cond_norm2, self.v_cond_cross_q, + self.v_cond_cross_q_norm, v_cond_mod2_shift, v_cond_mod2_scale, + self.num_heads, self.rope_dim_list) + + text_kv = self.text_cross_kv(cond) + text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads) + text_k = self.text_cross_k_norm(text_k).to(text_v) + + text_len = text_k.shape[1] + + text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim, + rope_dim_list=self.rope_dim_list) + text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device)) + text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1] + + v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1) + + if attention.__name__ == "attention_pytorch": + v_cond_audio_q, text_k, text_v = [t.transpose(1, 2) for t in (v_cond_audio_q, text_k, text_v)] + + cross_attn = attention(v_cond_audio_q, text_k, text_v, self.num_heads, skip_reshape = True) + v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + audio = audio + 
apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate) + v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate) + + audio = apply_modulated_block(audio, self.audio_norm3, audio_mod3_shift, audio_mod3_scale, self.audio_mlp, audio_mod3_gate) + v_cond = apply_modulated_block(v_cond, self.v_cond_norm3, v_cond_mod3_shift, v_cond_mod3_scale, self.v_cond_mlp, v_cond_mod3_gate) + + return audio, cond, v_cond + + def prepare_modulated_query(self, x, norm_layer, q_layer, q_norm_layer, shift, scale, num_heads, rope_dim_list): + + x_mod = modulate(norm_layer(x), shift=shift, scale=scale) + q = q_layer(x_mod) + + q = rearrange(q, "B L (H D) -> B L H D", H=num_heads) + q = q_norm_layer(q) + + head_dim = q.shape[-1] + freqs_cos, freqs_sin = self.build_rope_for_text(q.shape[1], head_dim, rope_dim_list) + freqs_cis = (freqs_cos.to(q.device), freqs_sin.to(q.device)) + + q = apply_rotary_emb(q, q, freqs_cis, head_first=False)[0] + + return q + +class SingleStreamBlock(nn.Module): + + def __init__(self, hidden_size: int, + num_heads: int, + mlp_ratio: float, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations = None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + + self.modulation = ModulateDiT( + hidden_size=hidden_size, + factor=6, + operations = operations, + **factory_kwargs, + ) + self.linear_qkv = operations.Linear(hidden_size, hidden_size * 3, bias=True) + self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, operations = operations, **factory_kwargs) + self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, operations = operations, **factory_kwargs) + self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs) + self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, **factory_kwargs) + self.q_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs) + self.k_norm = operations.RMSNorm(hidden_size // num_heads, **factory_kwargs) + self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads) + + def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None): + + modulation = self.modulation(cond) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1) + x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa + + qkv = self.linear_qkv(x_norm1) + q, k, v = self.rearrange(qkv).chunk(3, dim=-1) + + q, k, v = [t.squeeze(-1) for t in (q, k, v)] + + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True) + + q, k, v = [t.contiguous() for t in (q, k, v)] + + out = attention(q, k, v, self.num_heads, skip_output_reshape = True, skip_reshape = True) + out = rearrange(out, 'b h n d -> b n (h d)').contiguous() + + x = x + apply_gate(self.linear1(out),gate=gate_msa) + x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp + x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp) + + return x + +class HunyuanVideoFoley(nn.Module): + def __init__( + self, + model_args, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + operations = None + ): + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.depth_triple_blocks = 18 + self.depth_single_blocks = 36 + + self.interleaved_audio_visual_rope = 
model_args.get("interleaved_audio_visual_rope", True) + + self.condition_dim = model_args.get("condition_dim", 768) + + self.patch_size = model_args.get("patch_size", 1) + self.visual_in_channels = model_args.get("clip_dim", 768) + self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128) + self.out_channels = self.audio_vae_latent_dim + self.unpatchify_channels = self.out_channels + + self.num_heads = model_args.get("num_heads", 12) + self.hidden_size = model_args.get("hidden_size", 1536) + self.rope_dim_list = model_args.get("rope_dim_list", None) + self.mlp_ratio = model_args.get("mlp_ratio", 4.0) + + self.qkv_bias = model_args.get("qkv_bias", True) + self.qk_norm = model_args.get("qk_norm", True) + + # sync condition things + self.sync_modulation = model_args.get("sync_modulation", False) + self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", True) + self.sync_feat_dim = model_args.get("sync_feat_dim", 768) + self.sync_in_ksz = model_args.get("sync_in_ksz", 1) + + self.clip_len = model_args.get("clip_length", 64) + self.sync_len = model_args.get("sync_length", 192) + + self.patch_size = 1 + self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, operations=operations, **factory_kwargs) + self.visual_proj = SwiGLU(dim = self.visual_in_channels, hidden_dim = self.hidden_size, device=device, dtype=dtype, operations=operations) + + self.cond_in = ConditionProjection( + self.condition_dim, self.hidden_size, operations=operations, **factory_kwargs + ) + + self.time_in = TimestepEmbedder(self.hidden_size, operations = operations, **factory_kwargs) + + # visual sync embedder if needed + if self.sync_in_ksz == 1: + sync_in_padding = 0 + elif self.sync_in_ksz == 3: + sync_in_padding = 1 + else: + raise ValueError + if self.sync_modulation or self.add_sync_feat_to_audio: + self.sync_in = nn.Sequential( + operations.Linear(self.sync_feat_dim, self.hidden_size, **factory_kwargs), + nn.SiLU(), + ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding, operations=operations, **factory_kwargs), + ) + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim), **factory_kwargs)) + + self.triple_blocks = nn.ModuleList( + [ + TwoStreamCABlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qk_norm=self.qk_norm, + qkv_bias=self.qkv_bias, + interleaved_audio_visual_rope=self.interleaved_audio_visual_rope, + operations=operations, + **factory_kwargs, + ) + for _ in range(self.depth_triple_blocks) + ] + ) + + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + operations=operations, + **factory_kwargs, + ) + for _ in range(self.depth_single_blocks) + ] + ) + + self.final_layer = FinalLayer1D( + self.hidden_size, self.patch_size, self.out_channels, operations = operations,**factory_kwargs + ) + self.unpatchify_channels = self.out_channels + + self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False) + + def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor: + len = len if len is not None else self.clip_len + if bs is None: + return self.empty_clip_feat.expand(len, -1) # 15s + else: + return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s 
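+    # get_empty_clip_sequence()/get_empty_sync_sequence() broadcast the "empty" clip/sync feature vectors
+    # to a full-length sequence; forward() substitutes them for batch entries flagged by drop_visual
+    # so the visual conditioning is effectively dropped for those samples.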
+
+    def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
+        len = len if len is not None else self.sync_len
+        if bs is None:
+            return self.empty_sync_feat.expand(len, -1)
+        else:
+            return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
+
+    def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
+        target_ndim = 1 # n-d RoPE
+        rope_sizes = [audio_emb_len]
+        head_dim = self.hidden_size // self.num_heads
+        rope_dim_list = self.rope_dim_list
+        if rope_dim_list is None:
+            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+            rope_dim_list=rope_dim_list,
+            start=rope_sizes,
+            theta=10000,
+            use_real=True,
+            theta_rescale_factor=1.0,
+        )
+
+        target_ndim = 1
+        rope_sizes = [visual_cond_len]
+        head_dim = self.hidden_size // self.num_heads
+        rope_dim_list = self.rope_dim_list
+        if rope_dim_list is None:
+            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+        v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
+            rope_dim_list=rope_dim_list,
+            start=rope_sizes,
+            theta=10000,
+            use_real=True,
+            theta_rescale_factor=1.0,
+            freq_scaling=1.0 * audio_emb_len / visual_cond_len,
+        )
+        return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
+
+    def build_rope_for_interleaved_audio_visual(self, total_len):
+        target_ndim = 1 # n-d RoPE
+        rope_sizes = [total_len]
+        head_dim = self.hidden_size // self.num_heads
+        rope_dim_list = self.rope_dim_list
+        if rope_dim_list is None:
+            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
+        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+            rope_dim_list=rope_dim_list,
+            start=rope_sizes,
+            theta=10000,
+            use_real=True,
+            theta_rescale_factor=1.0,
+        )
+        return freqs_cos, freqs_sin
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        full_cond: torch.Tensor,
+        transformer_options = {},
+        drop_visual: Optional[List[bool]] = None,
+    ):
+        audio = x
+        bs, _, ol = x.shape
+        tl = ol // self.patch_size
+
+        # split the packed conditioning into its conditional/unconditional halves, then into clip, sync and text features
+        condition, uncondition = torch.chunk(full_cond, 2)
+        uncond_1, uncond_2, uncond_3 = torch.chunk(uncondition, 3)
+        clip_feat, sync_feat, cond = torch.chunk(condition, 3)
+        # re-batch each feature type as [uncond, cond]
+        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
+
+        if drop_visual is not None:
+            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
+            sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
+
+        vec = self.time_in(t)
+        sync_vec = None
+        if self.add_sync_feat_to_audio:
+            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb.to(sync_feat.device)
+            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)
+            sync_feat = self.sync_in.to(sync_feat.device)(sync_feat)
+            add_sync_feat_to_audio = (
+                F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
+            )
+
+        cond = self.cond_in(cond)
+        cond_seq_len = cond.shape[1]
+
+        audio = self.audio_embedder(x)
+        audio_seq_len = audio.shape[1]
+        v_cond = self.visual_proj(clip_feat)
+        v_cond_seq_len = v_cond.shape[1]
+        attn_mask = None
+
+
+        freqs_cos, freqs_sin = 
self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2) + v_freqs_cos = v_freqs_sin = None + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None + + if self.add_sync_feat_to_audio: + add_sync_layer = 0 + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + def block_wrap(**kwargs): + return block(**kwargs) + + for layer_num, block in enumerate(self.triple_blocks): + if self.add_sync_feat_to_audio and layer_num == add_sync_layer: + audio = audio + add_sync_feat_to_audio + triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec] + if ("triple_block", layer_num) in blocks_replace: + audio, cond, v_cond = blocks_replace[("triple_block", layer_num)]({ + "audio": triple_block_args[0], + "cond": triple_block_args[1], + "v_cond": triple_block_args[2], + "attn_mask": triple_block_args[3], + "vec": triple_block_args[4], + "freqs_cis": triple_block_args[5], + "v_freqs_cis": triple_block_args[6], + "sync_vec": triple_block_args[7] + }, {"original_block": block_wrap}) + else: + audio, cond, v_cond = block(*triple_block_args) + + x = audio + if sync_vec is not None: + vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1) + vec = torch.cat((vec, sync_vec), dim=1) + + freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len) + if self.add_sync_feat_to_audio: + vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1) + if len(self.single_blocks) > 0: + for layer_num, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + (freqs_cos, freqs_sin), + ] + if ("single_block", layer_num) in blocks_replace: + x = blocks_replace[("single_block", layer_num)]({ + "x": single_block_args[0], + "vec": single_block_args[1], + "freqs_cis": single_block_args[2] + }, {"original_block": block_wrap}) + else: + x = block(*single_block_args) + + audio = x + + if sync_vec is not None: + vec = sync_vec + audio = self.final_layer(audio, vec) + audio = self.unpatchify1d(audio, tl) + + uncond, cond = torch.chunk(2, audio) + return torch.cat([cond, uncond]) + + def unpatchify1d(self, x, l): + c = self.unpatchify_channels + p = self.patch_size + + x = x.reshape(shape=(x.shape[0], l, p, c)) + x = torch.einsum("ntpc->nctp", x) + audio = x.reshape(shape=(x.shape[0], c, l * p)) + return audio diff --git a/comfy/ldm/hunyuan_foley/syncformer.py b/comfy/ldm/hunyuan_foley/syncformer.py new file mode 100644 index 000000000..9f1973bfb --- /dev/null +++ b/comfy/ldm/hunyuan_foley/syncformer.py @@ -0,0 +1,991 @@ +#!/usr/bin/env python3 + +from functools import partial + +import torch +import torch.nn as nn +from einops import rearrange, repeat, einops + +from comfy.ldm.hunyuan3dv2_1.hunyuandit import MLP as Mlp +from transformers.modeling_outputs import BaseModelOutputWithPooling +from comfy.ldm.modules.attention import optimized_attention, TransformerEncoderComfyv +from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig + +from typing import Optional, Union, Tuple + +class Config: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + +class PatchEmbed(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, device=None, dtype=None, operations=None): + super().__init__() + img_size = img_size if type(img_size) is tuple else (img_size, img_size) + patch_size = 
img_size if type(patch_size) is tuple else (patch_size, patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + def __init__( + self, + img_size=224, + in_chans=3, + patch_size=16, + z_block_size=2, + embed_dim=768, + flatten=True, + device=None, dtype=None, operations=None + ): + super().__init__() + self.height = img_size // patch_size + self.width = img_size // patch_size + self.z_block_size = z_block_size + self.proj = operations.Conv3d( + in_chans, + embed_dim, + kernel_size=(z_block_size, patch_size, patch_size), + stride=(z_block_size, patch_size, patch_size), + device=device, dtype=dtype + ) + self.flatten = flatten + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + return x + +def qkv_attn(q, k, v, heads): + bh, seq_q, dim_head = q.shape + b = bh // heads + + # (b*heads, seq, dim) -> (b, heads, seq, dim) + q2 = q.view(b, heads, seq_q, dim_head) + k2 = k.view(b, heads, k.shape[1], dim_head) + v2 = v.view(b, heads, v.shape[1], dim_head) + + out = optimized_attention(q2, k2, v2, heads=heads, skip_reshape=True) + + out = out.permute(0, 2, 1, 3).contiguous().view(b * heads, seq_q, dim_head) + + return out + + +class DividedAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims): + h = self.num_heads + + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) + + cls_out = qkv_attn(cls_q, k, v, self.num_heads) + + q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_)) + + r = q_.shape[0] // cls_k.shape[0] + cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v)) + + k_ = torch.cat((cls_k, k_), dim=1) + v_ = torch.cat((cls_v, v_), dim=1) + + out = qkv_attn(q_, k_, v_, self.num_heads) + out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims) + + out = torch.cat((cls_out, out), dim=1) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + + x = self.proj(out) + return x + +class DividedSpaceTimeBlock(nn.Module): + + def __init__( + self, + dim=768, + num_heads=12, + qkv_bias=False, + norm_layer=nn.LayerNorm, + device = None, dtype = None, operations = None + ): + super().__init__() + + factory_kwargs = {"device":device, "dtype": dtype} + + self.einops_from_space = "b (f n) d" + self.einops_to_space = "(b f) n d" + self.einops_from_time = "b (f n) d" + self.einops_to_time = "(b n) f d" + + self.norm1 = norm_layer(dim) + + self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, operations = operations, **factory_kwargs) + + self.timeattn = DividedAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, operations=operations, **factory_kwargs + ) + + 
self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(width = dim, operations = operations, device=device, dtype=dtype) + self.norm3 = norm_layer(dim) + + def forward(self, x, seq_len=196, num_frames=8, tok_mask: torch.Tensor = None): + time_output = self.timeattn( + self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask + ) + time_residual = x + time_output + + space_output = self.attn( + self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask + ) + space_residual = time_residual + self.drop_path(space_output) + + x = space_residual + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class MotionFormer(nn.Module): + def __init__(self, device = None, dtype = None, operations = None): + super().__init__() + self.APPROX_ATTN_TYPE = "none" + self.APPROX_ATTN_DIM = 64 + self.img_size = 224 + self.patch_size = 16 + self.in_chans = 3 + self.num_classes = 174 + self.embed_dim = 768 + self.depth = 12 + self.num_heads = 12 + self.mlp_ratio = 4 + self.qkv_bias = True + self.drop_rate = 0.0 + self.drop_path_rate = 0.2 + self.temporal_resolution = 8 + self.use_mlp = True + self.num_features = self.embed_dim + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.attn_drop_rate = 0.0 + self.factorize_space_time = True + + # Patch Embedding + self.patch_embed = PatchEmbed( + img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim, + device=device, dtype=dtype, operations=operations + ) + + # 3D Patch Embedding + self.patch_embed_3d = PatchEmbed3D( + img_size=self.img_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + z_block_size = 2, + device=device, dtype=dtype, operations=operations + ) + self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data) + + # Number of patches + self.num_patches = self.patch_embed.num_patches * self.temporal_resolution + + # CLS token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, device=device, dtype=dtype)) + + # Positional embedding + self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim, device=device, dtype=dtype)) + self.pos_drop = nn.Dropout(p=0.0) + + self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim, device=device, dtype=dtype)) + + self.blocks = nn.ModuleList( + [ + DividedSpaceTimeBlock( + dim=self.embed_dim, + num_heads=self.num_heads, + qkv_bias=self.qkv_bias, + norm_layer=norm_layer, + device=device, dtype=dtype, operations=operations + ) + for _ in range(self.depth) + ] + ) + + self.norm = norm_layer(self.embed_dim) + + self.pre_logits = nn.Identity() + + transf_enc_layer_kwargs = dict( + d_model=self.embed_dim, + nhead=self.num_heads, + activation=nn.GELU(), + batch_first=True, + dim_feedforward=self.mlp_ratio * self.embed_dim, + dropout=self.drop_rate, + layer_norm_eps=1e-6, + norm_first=True, + ) + self.spatial_attn_agg = SpatialTransformerEncoderLayer(device = device, dtype=dtype, operations=operations,**transf_enc_layer_kwargs) + self.temp_attn_agg = nn.Identity() + + def forward_features(self, x): + + B = x.shape[0] + + # apply patching on input + x = self.patch_embed_3d(x) + tok_mask = None + + # Append CLS token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + new_pos_embed = self.pos_embed + npatch = self.patch_embed.num_patches + + cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) + 
tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) + tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) + total_pos_embed = tile_pos_embed + tile_temporal_embed + total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) + x = x + total_pos_embed + + # Apply positional dropout + x = self.pos_drop(x) + + # Encoding using transformer layers + for i, blk in enumerate(self.blocks): + x = blk( + x, + seq_len=npatch, + num_frames=self.temporal_resolution, + tok_mask=tok_mask, + ) + + return x, tok_mask + + def forward(self, x): + B, S, C, T, H, W = x.shape + + orig_shape = (B, S, C, T, H, W) + x = x.view(B * S, C, T, H, W) # flatten batch and segments + x = self.forward_segments(x, orig_shape=orig_shape) + x = x.view(B, S, *x.shape[1:]) + + return x + + def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor: + x, x_mask = self.forward_features(x) + + x = x[:, 1:, :] + x = self.norm(x) + x = self.pre_logits(x) + if self.factorize_space_time: + x = self.restore_spatio_temp_dims(x, orig_shape) + + x = self.spatial_attn_agg(x, x_mask) + x = self.temp_attn_agg(x) + + return x + + def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + + B, S, C, T, H, W = orig_shape + D = self.embed_dim + + # num patches in each dimension + t = T // self.patch_embed_3d.z_block_size + h = self.patch_embed_3d.height + w = self.patch_embed_3d.width + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w) + + return feats + +class BaseEncoderLayer(TransformerEncoderComfyv): + def __init__( + self, + add_pos_emb: bool = False, + pos_emb_drop: float = None, + pos_max_len: int = None, + device = None, + dtype = None, operations = None, + *args, **kwargs + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(operations = operations, *args, **kwargs, **factory_kwargs) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim, **factory_kwargs)) + + self.add_pos_emb = add_pos_emb + if add_pos_emb: + self.pos_max_len = 1 + pos_max_len + self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim, **factory_kwargs)) + self.pos_drop = nn.Dropout(pos_emb_drop) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): + batch_dim = x.shape[0] + + cls_tokens = self.cls_token.expand(batch_dim, -1, -1) + x = torch.cat((cls_tokens, x), dim=-2) + if x_mask is not None: + cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device) + x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) + B, N = x_mask_w_cls.shape + x_mask_w_cls = ( + x_mask_w_cls.reshape(B, 1, 1, N) + .expand(-1, self.self_attn.num_heads, N, -1) + .reshape(B * self.self_attn.num_heads, N, N) + ) + assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool" + x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask) + else: + x_mask_w_cls = None + + # add positional embedding + if self.add_pos_emb: + seq_len = x.shape[1] + assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})" + x = x + self.pos_emb[:, :seq_len, :] + x = self.pos_drop(x) + + x = super().forward(src=x, src_mask=x_mask_w_cls) + + x = x[:, 0, :] + + return x + +class SpatialTransformerEncoderLayer(BaseEncoderLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + BS, D, t, h, w = x.shape + + x = 
rearrange(x, "BS D t h w -> (BS t) (h w) D") + if x_mask is not None: + x_mask = rearrange(x_mask, "BS t h w -> (BS t) (h w)") + + x = super().forward(x=x, x_mask=x_mask) + + x = rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t) + + return x + +class AST(torch.nn.Module): + def __init__( + self, + max_spec_t: int = None, + factorize_freq_time: bool = None, + max_segments: int = None, + device = None, dtype = None, operations = None + ) -> None: + + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.extract_features = True + self.max_spec_t = max_spec_t + self.max_segments = max_segments + + self.config = ASTConfig() + self.config.num_labels = 527 + + self.ast = ASTModel(self.config, device=device, dtype=dtype, operations=operations) + + self.feat_type = "last_hidden_state" + self.factorize_freq_time = factorize_freq_time + + transf_enc_layer_kwargs = dict( + d_model=self.config.hidden_size, + nhead=self.config.num_attention_heads, + dim_feedforward=self.config.intermediate_size, + activation=torch.nn.GELU(), + batch_first=True, + dropout=self.config.attention_probs_dropout_prob, + layer_norm_eps=1e-6, + norm_first=True, + ) + if factorize_freq_time: + self.feat_type = "last_hidden_state" + self.freq_attn_agg = FrequencyTransformerEncoderLayer(operations = operations, **transf_enc_layer_kwargs, **factory_kwargs) + self.temp_attn_agg = torch.nn.Identity() + + self.device = device + + self.patch_position_emb() + + def forward( + self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs + ) -> torch.Tensor: + + B, S, T, F = x.shape + + if for_loop: + assert cont_mask is None, "cont_mask is not supported with for_loop=True" + orig_shape_s = (B, 1, T, F) + x = torch.cat( + [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1 + ) + else: + orig_shape = (B, S, T, F) + x = x.view(B * S, T, F) + if cont_mask is not None: + cont_mask = cont_mask.reshape(B * S, T, F) + x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs) + x = x.view(B, S, *x.shape[1:]) + + global_x = None + + return x, global_x + + def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs): + + x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs) + + if self.extract_features: + x = self.get_features_by_type(x) + if self.factorize_freq_time: + x = self.restore_freq_temp_dims(x, orig_shape) + if cont_mask is not None: + x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) + x_mask = self.restore_freq_temp_dims(x_mask, orig_shape) + x_mask = x_mask[:, 0, :, :] + else: + x_mask = None + x = self.freq_attn_agg(x, x_mask) + x = self.temp_attn_agg(x) + else: + x = x["pooler_output"] + x = self.classifier(x) + return x + + def get_features_by_type(self, x) -> torch.Tensor: + return x["last_hidden_state"] # (B, 2+T, D) + + def restore_freq_temp_dims(self, feats, orig_shape: tuple): + B, S, T, F = orig_shape + D = self.config.hidden_size + + # num patches in each dimension + f, t = self.ast.embeddings.get_shape(self.config) + + if self.feat_type == "last_hidden_state": + feats = feats[:, 2:, :] # removing CLS and distill tokens + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, f, t) # (B*S, D, f, t) + + return feats + + def patch_position_emb(self): + if self.max_spec_t is not None: + self.config.max_length = self.max_spec_t + f, t = self.ast.embeddings.get_shape(self.config) + shortened = 
self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens + self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device) + + def to(self, device): + self.device = torch.device(device) + return super().to(device) + + +class FrequencyTransformerEncoderLayer(BaseEncoderLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + BS, D, f, t = x.shape + + x = x.permute(0, 3, 2, 1) + x = x.reshape(BS * t, f, D) + if x_mask is not None: + x_mask = x_mask.permute(0, 2, 1) + x_mask = x_mask.reshape(BS * t, f) + + x = super().forward(x=x, x_mask=x_mask) + + x = x.view(BS, t, D) + + return x + +class ASTEmbeddings(nn.Module): + + def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype)) + self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = ASTPatchEmbeddings(config, device, dtype, operations) + + frequency_out_dimension, time_out_dimension = self.get_shape(config) + num_patches = frequency_out_dimension * time_out_dimension + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size, device=device, dtype=dtype)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def get_shape(self, config): + frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1 + time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1 + + return frequency_out_dimension, time_out_dimension + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + batch_size = input_values.shape[0] + embeddings = self.patch_embeddings(input_values) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + distillation_tokens = self.distillation_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +class ASTPatchEmbeddings(nn.Module): + def __init__(self, config, device = None, dtype = None, operations = None): + super().__init__() + + + patch_size = config.patch_size + frequency_stride = config.frequency_stride + time_stride = config.time_stride + + self.projection = operations.Conv2d( + 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride), device = device, dtype = dtype + ) + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + input_values = input_values.unsqueeze(1) + input_values = input_values.transpose(2, 3) + embeddings = self.projection(input_values).flatten(2).transpose(1, 2) + return embeddings + + +class ASTSelfAttention(nn.Module): + def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None: + super().__init__() + factory_kwargs = { "device": device, "dtype": dtype } + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs) + self.key = 
operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs) + self.value = operations.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias, **factory_kwargs) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + if tok_mask is not None: + attn_mask = (tok_mask == 0) + attn_mask = attn_mask[:, None, None, :] + else: + attn_mask = None + context_layer = optimized_attention(query_layer, key_layer, value_layer, self.num_attention_heads, mask = attn_mask, skip_output_reshape=True, skip_reshape=True) + context_layer = context_layer.view(*query_layer.size()) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return (context_layer,) + +class ASTSelfOutput(nn.Module): + + def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.dense = operations.Linear(config.hidden_size, config.hidden_size, device=device, dtype=dtype) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + +class ASTAttention(nn.Module): + def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.attention = ASTSelfAttention(config, device=device, dtype=dtype, operations=operations) + self.output = ASTSelfOutput(config, device=device, dtype=dtype, operations=operations) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, tok_mask, head_mask) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class ASTIntermediate(nn.Module): + def __init__(self, config: ASTConfig, device, dtype, operations) -> None: + super().__init__() + self.dense = operations.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.intermediate_act_fn = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ASTOutput(nn.Module): + def __init__(self, config: ASTConfig, device, dtype, operations) -> None: + super().__init__() + self.dense = operations.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + 
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + +class ASTLayer(nn.Module): + def __init__(self, config: ASTConfig, device=None, dtype=None, operations=None) -> None: + super().__init__() + factory_kwargs = {"device":device, "dtype":dtype} + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ASTAttention(config, operations = operations, **factory_kwargs) + self.intermediate = ASTIntermediate(config, operations=operations, **factory_kwargs) + self.output = ASTOutput(config, operations=operations, **factory_kwargs) + self.layernorm_before = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs) + self.layernorm_after = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), + tok_mask, + head_mask, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] + + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ASTEncoder(nn.Module): + def __init__(self, config: ASTConfig, device, dtype, operations) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ASTLayer(config, device, dtype, operations) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + + for i, layer_module in enumerate(self.layer): + layer_head_mask = head_mask[i] if head_mask is not None else None + layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions) + hidden_states = layer_outputs[0] + + return hidden_states + +class ASTModel(nn.Module): + def __init__(self, config: ASTConfig, device, dtype, operations): + super().__init__() + self.config = config + + self.embeddings = ASTEmbeddings(config, device, dtype, operations) + self.encoder = ASTEncoder(config, device, dtype, operations) + + self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def get_input_embeddings(self) -> ASTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def forward( + self, + input_values: Optional[torch.Tensor] = None, + cont_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ): + + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings(input_values) + + if cont_mask is not None: + indicator = torch.ones_like(input_values).to(input_values.dtype) + indicator[~cont_mask] = torch.inf + with torch.no_grad(): + indicator = self.embeddings(indicator) + tok_mask = ~torch.isnan(indicator) + tok_mask = tok_mask[:, :, 0] + else: + tok_mask = None + + 
encoder_outputs = self.encoder( + embedding_output, + tok_mask=tok_mask, + head_mask=head_mask, + ) + sequence_output = encoder_outputs + sequence_output = self.layernorm(sequence_output) + + pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2 + + return ( + BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ), + tok_mask, + ) + +class ASTMLPHead(nn.Module): + def __init__(self, config: ASTConfig, device, dtype, operations): + super().__init__() + self.layernorm = operations.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.dense = operations.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.layernorm(hidden_state) + hidden_state = self.dense(hidden_state) + return hidden_state + +class RandInitPositionalEncoding(nn.Module): + def __init__(self, block_shape: list, n_embd: int, device = None, dtype = None,): + super().__init__() + self.block_shape = block_shape + self.n_embd = n_embd + self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd, device=device, dtype=dtype)) + + def forward(self, token_embeddings): + return token_embeddings + self.pos_emb + + +class GlobalTransformer(torch.nn.Module): + def __init__( + self, + tok_pdrop=0.0, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + n_layer=3, + n_head=8, + n_embd=768, + pos_emb_block_shape=[ + 198, + ], + n_off_head_out=21, + device = None, dtype = None, operations = None + ) -> None: + super().__init__() + + factory_kwargs = {"device":device, "dtype": dtype} + self.config = Config( + embd_pdrop=embd_pdrop, + resid_pdrop=resid_pdrop, + attn_pdrop=attn_pdrop, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + ) + # input norm + self.vis_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs) + self.aud_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs) + # aux tokens + self.OFF_tok = operations.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs)) + self.MOD_tok = operations.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs)) + # whole token dropout + self.tok_pdrop = tok_pdrop + self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) + self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) + # maybe add pos emb + self.pos_emb_cfg = RandInitPositionalEncoding( + block_shape=pos_emb_block_shape, + n_embd=n_embd, + ) + # the stem + self.drop = torch.nn.Dropout(embd_pdrop) + self.blocks = operations.Sequential(*[Block(self.config, operations=operations, **factory_kwargs) for _ in range(n_layer)]) + # pre-output norm + self.ln_f = operations.LayerNorm(n_embd) + # maybe add a head + self.off_head = operations.Linear(in_features=n_embd, out_features=n_off_head_out) + + def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): + B, Sv, D = v.shape + B, Sa, D = a.shape + + off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) + mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) + + v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) + + if self.tok_pdrop > 0: + v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) + + x = torch.cat((off_tok, v, mod_tok, a), dim=1) + if hasattr(self, "pos_emb_cfg"): + x = self.pos_emb_cfg(x) + + x = self.drop(x) + x = self.blocks(x) + x = self.ln_f(x) + + if attempt_to_apply_heads and hasattr(self, "off_head"): + x = self.off_head(x[:, 0, :]) + return x + + +class SelfAttention(nn.Module): + + def __init__(self, config, device, dtype, 
operations): + super().__init__() + + self.key = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype) + self.query = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype) + self.value = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype) + + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + + self.proj = operations.Linear(config.n_embd, config.n_embd, device=device, dtype=dtype) + self.n_head = config.n_head + + def forward(self, x): + B, T, C = x.size() + + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + y = optimized_attention(q, k, v, self.n_head, skip_reshape=True) + + y = self.resid_drop(self.proj(y)) + + return y + + +class Block(nn.Module): + def __init__(self, config, device, dtype, operations): + super().__init__() + factory_kwargs = {"device":device, "dtype":dtype} + self.ln1 = operations.LayerNorm(config.n_embd, **factory_kwargs) + self.ln2 = operations.LayerNorm(config.n_embd, **factory_kwargs) + self.attn = SelfAttention(config, device, dtype, operations) + self.mlp = nn.Sequential( + operations.Linear(config.n_embd, 4 * config.n_embd, **factory_kwargs), + nn.GELU(), + operations.Linear(4 * config.n_embd, config.n_embd, **factory_kwargs), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class Synchformer(nn.Module): + + def __init__(self, device, dtype, operations): + super().__init__() + + factory_kwargs = {"device":device, "dtype":dtype} + + self.vfeat_extractor = MotionFormer(operations = operations, **factory_kwargs) + self.afeat_extractor = AST( + operations = operations, + max_spec_t = 66, + factorize_freq_time = True, + **factory_kwargs + ) + + self.vproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs) + self.aproj = operations.Linear(in_features=768, out_features=768, **factory_kwargs) + self.transformer = GlobalTransformer( + tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768, operations=operations, **factory_kwargs + ) + + def forward(self, vis): + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + vis = self.vfeat_extractor(vis) + return vis + + def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): + vis = self.vproj(vis) + aud = self.aproj(aud) + + B, S, tv, D = vis.shape + B, S, ta, D = aud.shape + vis = vis.view(B, S * tv, D) + aud = aud.view(B, S * ta, D) + + logits = self.transformer(vis, aud) + + return logits + + def extract_vfeats(self, vis): + return self.vfeat_extractor(vis.permute(0, 1, 3, 2, 4, 5)) + + def extract_afeats(self, aud): + B, S, _, Fa, Ta = aud.shape + aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) + aud, _ = self.afeat_extractor(aud) + return aud diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py new file mode 100644 index 000000000..7c4057072 --- /dev/null +++ b/comfy/ldm/hunyuan_foley/vae.py @@ -0,0 +1,85 @@ +import torch +import numpy as np +from typing import List +from einops import rearrange +from torchvision.transforms import v2 + +from comfy.ldm.hunyuan_foley.syncformer import Synchformer +from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder + +import comfy.ops +ops = comfy.ops.disable_weight_init + +class 
DAC(torch.nn.Module):
+    def __init__(
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        sample_rate: int = 44100,
+    ):
+        super().__init__()
+
+        self.encoder_dim = encoder_dim
+        self.encoder_rates = encoder_rates
+        self.decoder_dim = decoder_dim
+        self.decoder_rates = decoder_rates
+        self.sample_rate = sample_rate
+
+        if latent_dim is None:
+            latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+        self.latent_dim = latent_dim
+
+        self.hop_length = np.prod(encoder_rates)
+        self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
+
+        self.decoder = DACDecoder(
+            latent_dim,
+            decoder_dim,
+            decoder_rates,
+            operations = ops
+        )
+        self.sample_rate = sample_rate
+
+    def decode(self, z: torch.Tensor):
+        return self.decoder(z)
+
+    def forward(self):
+        pass
+
+class FoleyVae(torch.nn.Module):
+    def __init__(self):
+        super().__init__()  # must run before submodules are assigned to an nn.Module
+        self.dac = DAC()
+        self.syncformer = Synchformer(None, None, operations = ops)
+        self.syncformer_preprocess = v2.Compose(
+            [
+                v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
+                v2.CenterCrop(224),
+                v2.ToImage(),
+                v2.ToDtype(torch.float32, scale=True),
+                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            ]
+        )
+
+    def decode(self, x, vae_options = {}):
+        return self.dac.decode(x)
+
+    def encode(self, x):
+        return self.syncformer(x)
+
+    def video_encoding(self, video, step: int):
+
+        if not isinstance(video, torch.Tensor):
+            video = torch.from_numpy(video).permute(0, 3, 1, 2)
+
+        video = self.syncformer_preprocess(video).unsqueeze(0)
+        seg_len = 16
+        t = video.size(1)
+        nseg = max(0, (t - seg_len) // step + 1)
+        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
+        data = torch.stack(clips, dim=1)
+        data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
+
+        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
diff --git a/comfy/ldm/hydit/posemb_layers.py b/comfy/ldm/hydit/posemb_layers.py
index dcb41a713..0c2085405 100644
--- a/comfy/ldm/hydit/posemb_layers.py
+++ b/comfy/ldm/hydit/posemb_layers.py
@@ -158,7 +158,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
     return emb
 
-def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, freq_scaling = 1.0):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
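Note: the hunk below multiplies the precomputed RoPE frequencies by the new freq_scaling argument. A standalone sketch of the effect, illustration only and not part of the patch (the helper name is made up):

import torch

def rope_freqs(dim: int, theta: float = 10000.0, freq_scaling: float = 1.0) -> torch.Tensor:
    # Base 1D RoPE frequencies: 1 / theta^(2i/dim) for i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    # A uniform scale on every frequency: values < 1 stretch the effective
    # positional wavelengths, values > 1 compress them.
    return freqs * freq_scaling

print(rope_freqs(8))                    # frequencies 1, 0.1, 0.01, 0.001
print(rope_freqs(8, freq_scaling=0.5))  # the same frequencies halved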
@@ -180,6 +180,10 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float if isinstance(pos, int): pos = np.arange(pos) freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + + if freq_scaling != 1.0: + freqs *= freq_scaling + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] if use_real: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 043df28df..7febfcb7f 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -2,6 +2,7 @@ import math import sys import torch +from torch import Tensor import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat @@ -1032,4 +1033,162 @@ class SpatialVideoTransformer(SpatialTransformer): out = x + x_in return out +# comfyui implementation of nn.MultiheadAttention +class MultiheadAttentionComfyv(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + bias=True, + batch_first=False, + device=None, + dtype=None, + operations = None + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + + # to avoid pytorch checkpoint registeration + object.__setattr__(self, "_q_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs)) + object.__setattr__(self, "_k_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs)) + object.__setattr__(self, "_v_proj", operations.Linear(embed_dim, embed_dim, **factory_kwargs)) + + self.out_proj = operations.Linear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self.num_heads = num_heads + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + self.embed_dim = embed_dim + + # overwriting state dict loading to convert in_proj_weight/bias -> self._q_proj/_k_proj/_v_proj + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + def pop_key(k): + return state_dict.pop(k) if k in state_dict else None + + in_proj_w_key = prefix + "in_proj_weight" + in_proj_b_key = prefix + "in_proj_bias" + + weight = pop_key(in_proj_w_key) + if weight is not None: + q_w, k_w, v_w = torch.chunk(weight, 3) + self._q_proj.weight.data.copy_(q_w.to(self._q_proj.weight.device, dtype=self._q_proj.weight.dtype)) + self._k_proj.weight.data.copy_(k_w.to(self._k_proj.weight.device, dtype=self._k_proj.weight.dtype)) + self._v_proj.weight.data.copy_(v_w.to(self._v_proj.weight.device, dtype=self._v_proj.weight.dtype)) + + bias = pop_key(in_proj_b_key) + if bias is not None: + q_b, k_b, v_b = torch.chunk(bias, 3) + if getattr(self._q_proj, "bias", None) is not None: + self._q_proj.bias.data.copy_(q_b.to(self._q_proj.bias.device, dtype=self._q_proj.bias.dtype)) + self._k_proj.bias.data.copy_(k_b.to(self._k_proj.bias.device, dtype=self._k_proj.bias.dtype)) + self._v_proj.bias.data.copy_(v_b.to(self._v_proj.bias.device, dtype=self._v_proj.bias.dtype)) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward(self, src, attn_mask = None, key_padding_mask = None): + + q = self._q_proj(src) + k = self._k_proj(src) + v = self._v_proj(src) + + output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask) + return self.out_proj(output) + +# comfyui implementation of nn.TransformerEncoderLayer +class TransformerEncoderComfyv(nn.Module): + def __init__(self, + d_model, nhead, 
dim_feedforward, + norm_first = False, + layer_norm_eps: float = 1e-5, + bias: bool = True, + activation = F.relu, + device = None, + dtype = None, operations = None, **kwargs): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.linear1 = operations.Linear(d_model, dim_feedforward, **factory_kwargs) + self.linear2 = operations.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + self.norm2 = operations.LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) + + self.activation = activation + + self.self_attn = MultiheadAttentionComfyv( + embed_dim = d_model, + num_heads = nhead, + bias = bias, + operations = operations, + **factory_kwargs + ) + + def forward(self, src, src_key_padding_mask = None, src_mask = None): + src_key_padding_mask = F._canonical_mask( + mask=src_key_padding_mask, + mask_name="src_key_padding_mask", + other_type=F._none_or_dtype(src_mask), + other_name="src_mask", + target_type=src.dtype, + ) + + src_mask = F._canonical_mask( + mask=src_mask, + mask_name="src_mask", + other_type=None, + other_name="", + target_type=src.dtype, + check_other=False, + ) + + x = src + if self.norm_first: + x = x + self._sa_block( + self.norm1(x), src_mask, src_key_padding_mask + ) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1( + x + + self._sa_block(x, src_mask, src_key_padding_mask) + ) + x = self.norm2(x + self._ff_block(x)) + + return x + + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + ) + return x + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + return self.linear2(self.activation(self.linear1(x))) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3cdfbbf7e..939036e9c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -46,6 +46,7 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.higgsv2.model import comfy.ldm.qwen_image.model +import comfy.ldm.hunyuan_foley.model import comfy.model_management import comfy.patcher_extension @@ -1350,6 +1351,10 @@ class ACEStep(BaseModel): out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype)) out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out + +class HunyuanFoley(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.hunyuan_foley.model.HunyuanVideoFoley): + super().__init__(model_config, model_type, device, unet_model) class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 52f9995fe..cfe56c4b8 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -384,6 +384,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] return dit_config + + if '{}triple_blocks.17.audio_cross_q.weight'.format(key_prefix) in state_dict_keys: # Hunyuan Foley + return {} if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape diff --git a/comfy/sd.py b/comfy/sd.py index c68106826..975c6ea96 
100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
 import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.hunyuan_foley.vae
 import yaml
 import math
 import os
@@ -468,6 +469,12 @@ class VAE:
                 self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
                 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+            # Hunyuan Foley
+            elif "syncformer.afeat_extractor.ast.encoder.layer.11.attention.attention.key.weight" in sd:
+                self.latent_dim = 128
+                self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
+                # TODO
+                self.memory_used_encode = lambda shape, dtype: shape[0] * model_management.dtype_size(dtype)
             elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
                 self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 25e798f4b..c5c3276bd 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -4,6 +4,7 @@ from . import utils
 from . import sd1_clip
 from . import sdxl_clip
+import comfy.clap_model
 import comfy.text_encoders.sd2_clip
 import comfy.text_encoders.sd3_clip
 import comfy.text_encoders.sa_t5
@@ -1266,6 +1267,21 @@ class Omnigen2(supported_models_base.BASE):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
+
+class HunyuanFoley(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "hunyuan_foley",
+    }
+
+    latent_format = latent_formats.HunyuanFoley
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    vae_key_prefix = ["dac."]
+    text_encoder_key_prefix = ["clap."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.HunyuanFoley(self, device=device)
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(comfy.clap_model.ClapLargeTokenizer, comfy.clap_model.ClapTextEncoderModel)
 
 class QwenImage(supported_models_base.BASE):
     unet_config = {
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index f646504c8..4a8b357b3 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
                     return stream.width, stream.height
         raise ValueError(f"No video stream found in file '{self.__file}'")
 
-    def get_duration(self) -> float:
+    def get_duration(self, return_frames = False) -> float:
         """
         Returns the duration of the video in seconds.
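With the flag added above, get_duration either returns a plain float (the default) or a (duration_seconds, frame_count) pair; the hunk below implements this and the Foley latent node consumes it. A standalone sketch of that return convention (illustrative helper, not the real VideoFromFile method):

def get_duration_like(frames: int, fps: float, return_frames: bool = False):
    # Mirrors the convention: seconds only by default, (seconds, frame count) on request.
    duration = frames / fps
    return duration if not return_frames else (duration, float(frames))

print(get_duration_like(250, 25.0))                      # 10.0
print(get_duration_like(250, 25.0, return_frames=True))  # (10.0, 250.0)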
@@ -100,14 +100,16 @@
         self.__file.seek(0)
         with av.open(self.__file, mode="r") as container:
             if container.duration is not None:
-                return float(container.duration / av.time_base)
+                return float(container.duration / av.time_base) if not return_frames \
+                    else (float(container.duration / av.time_base), float(container.duration))
             # Fallback: calculate from frame count and frame rate
             video_stream = next(
                 (s for s in container.streams if s.type == "video"), None
             )
             if video_stream and video_stream.frames and video_stream.average_rate:
-                return float(video_stream.frames / video_stream.average_rate)
+                return float(video_stream.frames / video_stream.average_rate) if not return_frames \
+                    else (float(video_stream.frames / video_stream.average_rate), float(video_stream.frames))
             # Last resort: decode frames to count them
             if video_stream and video_stream.average_rate:
@@ -117,7 +119,8 @@
                     for _ in packet.decode():
                         frame_count += 1
                 if frame_count > 0:
-                    return float(frame_count / video_stream.average_rate)
+                    return float(frame_count / video_stream.average_rate) if not return_frames \
+                        else (float(frame_count / video_stream.average_rate), float(frame_count))
         raise ValueError(f"Could not determine duration for file '{self.__file}'")
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
new file mode 100644
index 000000000..b76ad4aa1
--- /dev/null
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -0,0 +1,53 @@
+import torch
+import comfy.model_management
+
+class EmptyLatentHunyuanFoley:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "length": ("INT", {"default": 12, "min": 1, "max": 15, "tooltip": "The length of the audio. The same length as the video."}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent audios in the batch."}),
+            },
+            "optional": {"video": ("VIDEO",)}
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/audio"
+
+    def generate(self, length, batch_size, video = None):
+        if video is not None:
+            _, length = video.get_duration(return_frames = True)
+            length /= 25
+        shape = (batch_size, 128, int(50 * length))
+        latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
+        return ({"samples": latent, "type": "hunyuan_foley"}, )
+
+class HunyuanFoleyConditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"video_encoding_siglip": ("CONDITIONING",),
+                             "video_encoding_synchformer": ("CONDITIONING",),
+                             "text_encoding": ("CONDITIONING",)
+                             },
+                }
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
+    RETURN_NAMES = ("positive", "negative")
+
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/video_models"
+
+    def encode(self, video_encoding_siglip, video_encoding_synchformer, text_encoding):
+        embeds = torch.cat([video_encoding_siglip, video_encoding_synchformer, text_encoding], dim = 0)
+        positive = [[embeds, {}]]
+        negative = [[torch.zeros_like(embeds), {}]]
+        return (positive, negative)
+
+NODE_CLASS_MAPPINGS = {
+    "HunyuanFoleyConditioning": HunyuanFoleyConditioning,
+    "EmptyLatentHunyuanFoley": EmptyLatentHunyuanFoley,
+}
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 69fabb12e..a7c0f19e9 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import os
+import io
 import av
 import torch
 import folder_paths
 import json
+import numpy as np
 from typing import Optional
 from typing_extensions import override
 from
fractions import Fraction @@ -14,6 +16,116 @@ from comfy_api.util import VideoCodec, VideoComponents, VideoContainer from comfy_api.latest import ComfyExtension, io, ui from comfy.cli_args import args +class EncodeVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EncodeVideo", + display_name="Encode Video", + category="image/video", + description="Encode a video using an image encoder.", + inputs=[ + io.Video.Input("video", tooltip="The video to be encoded."), + io.Int.Input( + "processing_batch_size", default=-1, min=-1, + tooltip=( + "Number of frames/segments to process at a time during encoding.\n" + "-1 means process all at once. Smaller values reduce GPU memory usage." + ), + ), + io.Int.Input("step_size", default=8, min=1, max=32, + tooltip=( + "Stride (in frames) between the start of consecutive segments.\n" + "Smaller step = more overlap and smoother temporal coverage " + "but higher compute cost. Larger step = faster but may miss detail." + ), + ), + io.Vae.Input("vae", optional=True), + io.ClipVision.Input("clip_vision", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="encoded_video"), + ], + ) + + @classmethod + def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None): + b, t, c, h, w = video.shape + batch_size = b * t + + if vae is None and clip_vision is None: + raise ValueError("Must either have vae or clip_vision.") + vae = vae if vae is not None else clip_vision + + if hasattr(vae.first_stage_model, "video_encoding"): + data, num_segments, output_fn = vae.video_encoding(video, step_size) + batch_size = b * num_segments + else: + data = video.view(batch_size, c, h, w) + output_fn = lambda x: x.view(b, t, -1) + + if processing_batch_size != -1: + batch_size = processing_batch_size + + outputs = [] + total = data.shape[0] + for i in range(0, total, batch_size): + chunk = data[i : i + batch_size] + out = vae.encode(chunk) + outputs.append(out) + + output = torch.cat(outputs) + + return output_fn(output) + +class ResampleVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResampleVideo", + display_name="Resample Video", + category="image/video", + inputs = [ + io.Video.Input("video"), + io.Int.Input("target_fps") + ], + outputs=[io.Image.Output(display_name="images")] + ) + @classmethod + def execute(cls, container: av.container.InputContainer, target_fps: int): + # doesn't support upsampling + + stream = container.streams.video[0] + frames = [] + + src_rate = stream.average_rate or stream.guessed_rate + src_fps = float(src_rate) if src_rate else None + + # yield original frames if asked for upsampling or src is unknown + if src_fps is None or target_fps > src_fps: + for packet in container.demux(stream): + for frame in packet.decode(): + arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0 + frames.append(arr) + return torch.stack(frames) + + stream.thread_type = "AUTO" + + next_time = 0.0 + step = 1.0 / target_fps + + for packet in container.demux(stream): + for frame in packet.decode(): + if frame.time is None: + continue + t = frame.time + while t >= next_time: + arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0 + frames.append(arr) + next_time += step + + return torch.stack(frames) + class SaveWEBM(io.ComfyNode): @classmethod def define_schema(cls): @@ -212,6 +324,8 @@ class VideoExtension(ComfyExtension): CreateVideo, GetVideoComponents, LoadVideo, + EncodeVideo, + ResampleVideo ] async def 
comfy_entrypoint() -> VideoExtension:
diff --git a/nodes.py b/nodes.py
index 35c83500e..9a0251d37 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2378,6 +2378,7 @@ async def init_builtin_extra_nodes():
         "nodes_model_patch.py",
         "nodes_easycache.py",
         "nodes_audio_encoder.py",
+        "nodes_hunyuan_foley.py",
     ]
 
     import_failed = []
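For reference, a standalone sketch of the sizing arithmetic the new nodes rely on. The constants are taken from the patch (25 fps input video, 50 latent steps per second, 128 latent channels, 16-frame Synchformer segments with a default stride of 8); this is an illustration, not ComfyUI code:

def foley_latent_shape(video_frames: int, batch_size: int = 1):
    # EmptyLatentHunyuanFoley: audio latent sized from the video length at 25 fps,
    # with 128 channels and 50 latent steps per second of audio.
    length_seconds = video_frames / 25
    return (batch_size, 128, int(50 * length_seconds))

def synchformer_segments(video_frames: int, step: int = 8, seg_len: int = 16):
    # FoleyVae.video_encoding: number of sliding 16-frame windows taken every `step` frames.
    return max(0, (video_frames - seg_len) // step + 1)

print(foley_latent_shape(300))    # (1, 128, 600) for a 12 s clip
print(synchformer_segments(300))  # 36 windows at the default stride of 8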