import math
from typing import Optional

import torch
from torch import nn

import comfy.ldm.common_dit
from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    generate_freq_grid_np,
    interleaved_freqs_cis,
    split_freqs_cis,
)


class BasicTransformerBlock1D(nn.Module):
    r"""
    A basic 1D Transformer block: pre-norm self-attention followed by a
    pre-norm feed-forward layer, each wrapped in a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        n_heads (`int`): The number of heads to use for multi-head attention.
        d_head (`int`): The number of channels in each attention head.
        context_dim (`int`, *optional*): Accepted for interface compatibility;
            the block is self-attention only, so it is not used.
        attn_precision (*optional*): Accepted for interface compatibility; not used.
        dtype, device, operations: Forwarded to the attention and feed-forward
            sub-modules.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        # Two sub-blocks; normalization is applied functionally in forward().
        # 1. Self-attention
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # 2. Feed-forward
        self.ff = FeedForward(
            dim,
            dim_out=dim,
            glu=True,
            dtype=dtype,
            device=device,
            operations=operations,
        )
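    # Dataflow of forward() below, for reference:
    #   x = x + SelfAttn(RMSNorm(x), pe)   # optional rotary embedding via `pe`
    #   x = x + FeedForward(RMSNorm(x))
    # RMS normalization is functional (comfy.ldm.common_dit.rms_norm), so the
    # block itself carries no learnable normalization parameters.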
    def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
        # Normalization is always applied before the actual computation in each block.
        # 1. Normalization before self-attention
        norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
        norm_hidden_states = norm_hidden_states.squeeze(1)

        # 2. Self-attention (with residual connection)
        attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Normalization before feed-forward
        norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)

        # 4. Feed-forward (with residual connection)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class Embeddings1DConnector(nn.Module):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels=128,
        cross_attention_dim=2048,
        attention_head_dim=128,
        num_attention_heads=30,
        num_layers=2,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[4096],
        causal_temporal_positioning=False,
        num_learnable_registers: Optional[int] = 128,
        dtype=None,
        device=None,
        operations=None,
        split_rope=False,
        double_precision_rope=False,
        **kwargs,
    ):
        super().__init__()
        self.dtype = dtype
        self.out_channels = in_channels
        self.num_attention_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.split_rope = split_rope
        self.double_precision_rope = double_precision_rope

        self.transformer_1d_blocks = nn.ModuleList(
            [
                BasicTransformerBlock1D(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(num_layers)
            ]
        )

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            # Learnable register tokens, initialized uniformly in [-1, 1].
            self.learnable_registers = nn.Parameter(
                torch.rand(
                    self.num_learnable_registers,
                    self.inner_dim,
                    dtype=dtype,
                    device=device,
                )
                * 2.0
                - 1.0
            )

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(1)
            ],
            dim=-1,
        )
        return fractional_positions

    def precompute_freqs(self, indices_grid, spacing):
        source_dtype = indices_grid.dtype
        dtype = (
            torch.float32
            if source_dtype in (torch.bfloat16, torch.float16)
            else source_dtype
        )

        fractional_positions = self.get_fractional_positions(indices_grid)

        indices = (
            generate_freq_grid_np(
                self.positional_embedding_theta,
                indices_grid.shape[1],
                self.inner_dim,
            )
            if self.double_precision_rope
            else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
        ).to(device=fractional_positions.device)

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )

        return freqs
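    # Sketch of the RoPE frequency layout produced by precompute_freqs() on the
    # default path (spacing="exp", double_precision_rope=False):
    #   - generate_freq_grid() returns dim // 2 frequencies, spaced geometrically
    #     from 1 to positional_embedding_theta and scaled by pi / 2.
    #   - fractional positions are token indices divided by
    #     positional_embedding_max_pos[0], then mapped from [0, 1] to [-1, 1].
    #   - their outer product, flattened to (batch, tokens, dim // 2), is what
    #     interleaved_freqs_cis / split_freqs_cis convert into cos/sin tables.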
    def generate_freq_grid(self, spacing, dtype, device):
        dim = self.inner_dim
        theta = self.positional_embedding_theta

        n_pos_dims = 1
        n_elem = 2 * n_pos_dims  # two entries (cos and sin) per positional dimension
        start = 1
        end = theta

        if spacing == "exp":
            indices = theta ** (
                torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32)
                / (dim - n_elem)
            )
            indices = indices.to(dtype=dtype, device=device)
        elif spacing == "exp_2":
            indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(
                start, end, dim // n_elem, device=device, dtype=dtype
            )
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // n_elem, device=device, dtype=dtype
            ).sqrt()

        indices = indices * math.pi / 2
        return indices

    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dim = self.inner_dim
        n_elem = 2  # cos and sin
        freqs = self.precompute_freqs(indices_grid, spacing)

        if self.split_rope:
            expected_freqs = dim // 2
            current_freqs = freqs.shape[-1]
            pad_size = expected_freqs - current_freqs
            cos_freq, sin_freq = split_freqs_cis(
                freqs, pad_size, self.num_attention_heads
            )
        else:
            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)

        return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, inner_dim)`):
                Input sequence of embeddings.
            attention_mask (`torch.Tensor`, *optional*):
                Attention bias broadcastable over the attention scores; large
                negative values mask out the corresponding key tokens. When
                learnable registers are appended and a mask is provided, it is
                replaced with an all-zero (fully visible) bias covering the
                extended sequence.

        Returns:
            A tuple `(hidden_states, attention_mask)`: the RMS-normalized output
            sequence (learnable registers included) and the possibly replaced mask.
        """
        # 1. Input: pad the sequence with learnable register tokens.
        if self.num_learnable_registers:
            num_registers_duplications = math.ceil(
                max(1024, hidden_states.shape[1]) / self.num_learnable_registers
            )
            learnable_registers = torch.tile(
                self.learnable_registers.to(hidden_states),
                (num_registers_duplications, 1),
            )
            hidden_states = torch.cat(
                (
                    hidden_states,
                    learnable_registers[hidden_states.shape[1]:]
                    .unsqueeze(0)
                    .repeat(hidden_states.shape[0], 1, 1),
                ),
                dim=1,
            )
            if attention_mask is not None:
                # The original mask no longer matches the extended sequence;
                # replace it with an all-visible bias over tokens + registers.
                attention_mask = torch.zeros(
                    [1, 1, 1, hidden_states.shape[1]],
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

        # RoPE over absolute token positions of the (extended) sequence.
        indices_grid = torch.arange(
            hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
        )
        indices_grid = indices_grid[None, None, :]
        freqs_cis = self.precompute_freqs_cis(indices_grid)

        # 2. Blocks
        for block in self.transformer_1d_blocks:
            hidden_states = block(
                hidden_states, attention_mask=attention_mask, pe=freqs_cis
            )

        # 3. Output
        hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)

        return hidden_states, attention_mask
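
# Minimal usage sketch (illustrative only, not part of the model code). The batch
# size, sequence length, and dtype below are arbitrary assumptions, and
# `comfy.ops.disable_weight_init` is assumed to be an acceptable `operations`
# provider, as elsewhere in ComfyUI.
if __name__ == "__main__":
    import comfy.ops

    connector = Embeddings1DConnector(
        in_channels=128,
        attention_head_dim=128,
        num_attention_heads=30,
        num_layers=2,
        dtype=torch.float32,
        operations=comfy.ops.disable_weight_init,
    )
    # Inputs are expected to already have `inner_dim` channels; this module
    # applies no input projection.
    tokens = torch.randn(1, 512, connector.inner_dim)
    out, mask = connector(tokens, attention_mask=None)
    # The sequence is padded with learnable registers up to a multiple of
    # num_learnable_registers that is at least 1024, so `out` is (1, 1024, 3840).
    print(out.shape, mask)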