from comfy.ldm.cosmos.predict2 import MiniTrainDIT

import torch
from torch import nn
import torch.nn.functional as F


def rotate_half(x):
    # Standard rotary-embedding helper: swap the two halves of the last
    # dimension and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    # Apply the rotary position embedding to x; unsqueeze_dim=1 broadcasts
    # cos/sin over the head axis of tensors shaped [batch, heads, seq, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.rope_theta = 10000
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # `x` is used only for its device/dtype; the angles come from position_ids.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


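# Shape conventions for the rotary helpers above (illustrative note, not part of
# the original file): with head_dim=64, `inv_freq` has shape [32]; for
# `position_ids` of shape [B, S] the returned cos/sin are [B, S, 64], and
# apply_rotary_pos_emb broadcasts them over the head axis of a
# [B, heads, S, 64] tensor via unsqueeze_dim=1.

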
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None):
        super().__init__()

        inner_dim = head_dim * n_heads
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.query_dim = query_dim
        self.context_dim = context_dim

        self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype)

        self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype)

        self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
        # Self-attention when no context is given, cross-attention otherwise.
        context = x if context is None else context
        input_shape = x.shape[:-1]
        q_shape = (*input_shape, self.n_heads, self.head_dim)
        context_shape = context.shape[:-1]
        kv_shape = (*context_shape, self.n_heads, self.head_dim)

        # Project, apply per-head RMSNorm to q/k, and move heads to dim 1.
        query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
        value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)

        if position_embeddings is not None:
            assert position_embeddings_context is not None
            cos, sin = position_embeddings
            query_states = apply_rotary_pos_emb(query_states, cos, sin)
            cos, sin = position_embeddings_context
            key_states = apply_rotary_pos_emb(key_states, cos, sin)

        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)

        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output

    def init_weights(self):
        torch.nn.init.zeros_(self.o_proj.weight)


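# Illustrative shapes (not part of the original file): for cross-attention with
# query_dim=context_dim=1024, n_heads=16, head_dim=64, an input x of
# [B, T, 1024] and context of [B, S, 1024] give q of [B, 16, T, 64] and k/v of
# [B, 16, S, 64]; o_proj maps the merged heads back to [B, T, 1024].

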
class TransformerBlock(nn.Module):
    def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.use_self_attn = use_self_attn

        if self.use_self_attn:
            self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
            self.self_attn = Attention(
                query_dim=model_dim,
                context_dim=model_dim,
                n_heads=num_heads,
                head_dim=model_dim // num_heads,
                device=device,
                dtype=dtype,
                operations=operations,
            )

        self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
        self.cross_attn = Attention(
            query_dim=model_dim,
            context_dim=source_dim,
            n_heads=num_heads,
            head_dim=model_dim // num_heads,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype),
            nn.GELU(),
            operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype),
        )

    def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None):
        # Pre-norm residual blocks: optional self-attention, then cross-attention
        # over `context`, then the MLP.
        if self.use_self_attn:
            normed = self.norm_self_attn(x)
            attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings)
            x = x + attn_out

        normed = self.norm_cross_attn(x)
        attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
        x = x + attn_out

        x = x + self.mlp(self.norm_mlp(x))
        return x

    def init_weights(self):
        torch.nn.init.zeros_(self.mlp[2].weight)
        self.cross_attn.init_weights()


class LLMAdapter(nn.Module):
    def __init__(
        self,
        source_dim=1024,
        target_dim=1024,
        model_dim=1024,
        num_layers=6,
        num_heads=16,
        use_self_attn=True,
        layer_norm=False,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()

        self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype)
        if model_dim != target_dim:
            self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype)
        else:
            self.in_proj = nn.Identity()
        self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
        self.blocks = nn.ModuleList([
            TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)
        ])
        self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype)
        self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype)

    def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
        # Broadcast 2D padding masks to the [batch, heads, query, key] layout
        # expected by scaled_dot_product_attention.
        if target_attention_mask is not None:
            target_attention_mask = target_attention_mask.to(torch.bool)
            if target_attention_mask.ndim == 2:
                target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)

        if source_attention_mask is not None:
            source_attention_mask = source_attention_mask.to(torch.bool)
            if source_attention_mask.ndim == 2:
                source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)

        x = self.in_proj(self.embed(target_input_ids))
        context = source_hidden_states
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
        position_embeddings = self.rotary_emb(x, position_ids)
        # `x` is passed only for its device/dtype; the angles come from the ids.
        position_embeddings_context = self.rotary_emb(x, position_ids_context)
        for block in self.blocks:
            x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context)
        return self.norm(self.out_proj(x))


class Anima(MiniTrainDIT):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))

    def preprocess_text_embeds(self, text_embeds, text_ids):
        # Route the text embeddings through the adapter when token ids are
        # available; otherwise pass them through unchanged.
        if text_ids is not None:
            return self.llm_adapter(text_embeds, text_ids)
        else:
            return text_embeds
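

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it exercises
    # LLMAdapter on its own, using torch.nn as a stand-in for ComfyUI's
    # `operations` namespace. This assumes torch.nn exposes compatible Linear,
    # LayerNorm, RMSNorm, and Embedding modules (nn.RMSNorm requires
    # PyTorch >= 2.4); the dimensions below are arbitrary illustrative values.
    adapter = LLMAdapter(
        source_dim=1024,
        target_dim=1024,
        model_dim=1024,
        num_layers=2,
        num_heads=16,
        operations=torch.nn,
    )
    source = torch.randn(1, 77, 1024)              # e.g. encoder hidden states
    target_ids = torch.randint(0, 32128, (1, 64))  # T5-style vocabulary ids
    out = adapter(source, target_ids)
    print(out.shape)  # expected: torch.Size([1, 64, 1024])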