import math

from typing import Optional

import comfy.ldm.common_dit
import torch
from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    generate_freq_grid_np,
    interleaved_freqs_cis,
    split_freqs_cis,
)
from torch import nn


class BasicTransformerBlock1D(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
        qk_norm (`str`, *optional*, defaults to `None`):
            Set to `"layer_norm"` or `"rms_norm"` to perform query and key normalization.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
        ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
        attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
        use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
        ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        # The block has two sub-layers; each is preceded by its own normalization.
        # 1. Self-attention
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # 2. Feed-forward
        self.ff = FeedForward(
            dim,
            dim_out=dim,
            glu=True,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.

        # 1. Normalization before self-attention
        norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)

        norm_hidden_states = norm_hidden_states.squeeze(1)

        # 2. Self-attention
        attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Normalization before feed-forward
        norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
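
# Minimal usage sketch for BasicTransformerBlock1D (illustrative only; the shapes and
# the `operations` choice below are assumptions, not part of this module):
#
#   import comfy.ops
#   block = BasicTransformerBlock1D(
#       dim=3840, n_heads=30, d_head=128, operations=comfy.ops.disable_weight_init
#   )
#   x = torch.randn(2, 77, 3840)
#   y = block(x)  # pre-norm self-attention + feed-forward with residuals; shape (2, 77, 3840)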


class Embeddings1DConnector(nn.Module):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels=128,
        cross_attention_dim=2048,
        attention_head_dim=128,
        num_attention_heads=30,
        num_layers=2,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[4096],
        causal_temporal_positioning=False,
        num_learnable_registers: Optional[int] = 128,
        dtype=None,
        device=None,
        operations=None,
        split_rope=False,
        double_precision_rope=False,
        **kwargs,
    ):
        super().__init__()
        self.dtype = dtype
        self.out_channels = in_channels
        self.num_attention_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.split_rope = split_rope
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = nn.ModuleList(
            [
                BasicTransformerBlock1D(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(num_layers)
            ]
        )

        inner_dim = num_attention_heads * attention_head_dim
        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            self.learnable_registers = nn.Parameter(
                torch.rand(
                    self.num_learnable_registers, inner_dim, dtype=dtype, device=device
                )
                * 2.0
                - 1.0
            )
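            # Learnable register tokens, initialized uniformly in [-1, 1); they are
            # appended to the input sequence in forward() to pad it out.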

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(1)
            ],
            dim=-1,
        )
        return fractional_positions
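
    # Example (illustrative): with the default positional_embedding_max_pos=[4096],
    # a token at index 1024 maps to the fractional position 1024 / 4096 = 0.25.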

    def precompute_freqs(self, indices_grid, spacing):
        source_dtype = indices_grid.dtype
        dtype = (
            torch.float32
            if source_dtype in (torch.bfloat16, torch.float16)
            else source_dtype
        )

        fractional_positions = self.get_fractional_positions(indices_grid)
        indices = (
            generate_freq_grid_np(
                self.positional_embedding_theta,
                indices_grid.shape[1],
                self.inner_dim,
            )
            if self.double_precision_rope
            else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
        ).to(device=fractional_positions.device)

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )
        return freqs

    def generate_freq_grid(self, spacing, dtype, device):
        dim = self.inner_dim
        theta = self.positional_embedding_theta
        n_pos_dims = 1
        n_elem = 2 * n_pos_dims  # 2 components (cos and sin) per positional dimension
        start = 1
        end = theta

        if spacing == "exp":
            indices = theta ** (
                torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32)
                / (dim - n_elem)
            )
            indices = indices.to(dtype=dtype, device=device)
        elif spacing == "exp_2":
            indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(
                start, end, dim // n_elem, device=device, dtype=dtype
            )
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // n_elem, device=device, dtype=dtype
            ).sqrt()

        indices = indices * math.pi / 2

        return indices
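
    # Worked example (illustrative, default "exp" spacing): with theta=10000 and
    # inner_dim=3840, the grid is indices[k] = (pi/2) * 10000 ** (2k / 3838) for
    # k = 0 .. 1919. precompute_freqs() then multiplies these frequencies by the
    # rescaled fractional positions (2 * p - 1) in [-1, 1] to obtain per-token
    # phase angles for RoPE.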
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
|
dim = self.inner_dim
|
|
n_elem = 2 # 2 because of cos and sin
|
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
|
if self.split_rope:
|
|
expected_freqs = dim // 2
|
|
current_freqs = freqs.shape[-1]
|
|
pad_size = expected_freqs - current_freqs
|
|
cos_freq, sin_freq = split_freqs_cis(
|
|
freqs, pad_size, self.num_attention_heads
|
|
)
|
|
else:
|
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
|
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
|
|
|

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Embeddings1DConnector`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, inner_dim)`):
                Input `hidden_states`.
            attention_mask (`torch.Tensor`, *optional*):
                An attention mask applied inside the transformer blocks. If `1` the token is
                kept, otherwise if `0` it is discarded. The mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to
                "discard" tokens.

        Returns:
            A `tuple` of the processed `hidden_states` (with the learnable registers appended)
            and the possibly replaced `attention_mask`.
        """
        # 1. Input
        if self.num_learnable_registers:
            num_registers_duplications = math.ceil(
                max(1024, hidden_states.shape[1]) / self.num_learnable_registers
            )
            learnable_registers = torch.tile(
                self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
            )

            hidden_states = torch.cat(
                (
                    hidden_states,
                    learnable_registers[hidden_states.shape[1]:]
                    .unsqueeze(0)
                    .repeat(hidden_states.shape[0], 1, 1),
                ),
                dim=1,
            )

            if attention_mask is not None:
                attention_mask = torch.zeros(
                    [1, 1, 1, hidden_states.shape[1]],
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
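
        # Illustrative arithmetic (assuming the default num_learnable_registers=128):
        # for a 77-token input, max(1024, 77) / 128 = 8 duplications, so the registers
        # are tiled to 1024 rows and rows 77..1023 are appended, giving a 1024-token
        # sequence. Any provided attention mask is then replaced by an all-zero mask of
        # the new length, so the appended registers are effectively not masked out.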

        indices_grid = torch.arange(
            hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
        )
        indices_grid = indices_grid[None, None, :]
        freqs_cis = self.precompute_freqs_cis(indices_grid)

        # 2. Blocks
        for block_idx, block in enumerate(self.transformer_1d_blocks):
            hidden_states = block(
                hidden_states, attention_mask=attention_mask, pe=freqs_cis
            )

        # 3. Output
        # if self.output_scale is not None:
        #     hidden_states = hidden_states / self.output_scale

        hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)

        return hidden_states, attention_mask
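

# Minimal usage sketch for Embeddings1DConnector (illustrative only; the tensor shapes
# and the `operations` choice are assumptions, not part of this module):
#
#   import comfy.ops
#   connector = Embeddings1DConnector(operations=comfy.ops.disable_weight_init)
#   text_emb = torch.randn(1, 77, 30 * 128)   # (batch, tokens, inner_dim = 30 heads * 128)
#   out, mask = connector(text_emb)           # out: (1, 1024, 3840); registers pad to 1024 tokens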