diff --git a/comfy/ldm/twinflow/model.py b/comfy/ldm/twinflow/model.py new file mode 100644 index 000000000..a99bcd06e --- /dev/null +++ b/comfy/ldm/twinflow/model.py @@ -0,0 +1,700 @@ +""" +TwinFlow-Z-Image custom model architecture for ComfyUI. +Based on the Lumina-Image 2.0 / Z-Image architecture. +Supports the unique dual timestep embedding architecture of TwinFlow. +""" + +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import comfy.ldm.common_dit + +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope +import comfy.patcher_extension + + +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + + +def modulate(x, scale): + return x * (1 + scale.unsqueeze(1)) + + +class JointAttention(nn.Module): + """Multi-head attention module with combined QKV weights.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + out_bias: bool = False, + operation_settings={}, + ): + super().__init__() + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads + self.n_local_kv_heads = self.n_kv_heads + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.qkv = operation_settings.get("operations").Linear( + dim, + (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.out = operation_settings.get("operations").Linear( + n_heads * self.head_dim, + dim, + bias=out_bias, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + if qk_norm: + self.q_norm = operation_settings.get("operations").RMSNorm( + self.head_dim, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.k_norm = operation_settings.get("operations").RMSNorm( + self.head_dim, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + else: + self.q_norm = self.k_norm = nn.Identity() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + transformer_options={}, + ) -> torch.Tensor: + bsz, seqlen, _ = x.shape + + xq, xk, xv = torch.split( + self.qkv(x), + [ + self.n_local_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + self.n_local_kv_heads * self.head_dim, + ], + dim=-1, + ) + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + xq, xk = apply_rope(xq, xk, freqs_cis) + + n_rep = self.n_local_heads // self.n_local_kv_heads + if n_rep >= 1: + xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + output = optimized_attention_masked( + xq.movedim(1, 2), + xk.movedim(1, 2), + xv.movedim(1, 2), + self.n_local_heads, + x_mask, + skip_reshape=True, + transformer_options=transformer_options, + ) + + return self.out(output) + + +class FeedForward(nn.Module): + """Feed-forward module with SiLU gating.""" + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + operation_settings={}, + ): + super().__init__() + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = operation_settings.get("operations").Linear( + dim, + hidden_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.w2 = operation_settings.get("operations").Linear( + hidden_dim, + dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.w3 = operation_settings.get("operations").Linear( + dim, + hidden_dim, + bias=False, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + def _forward_silu_gating(self, x1, x3): + return clamp_fp16(F.silu(x1) * x3) + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +class TwinFlowTransformerBlock(nn.Module): + """Transformer block with adaLN modulation for TwinFlow.""" + + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=False, + z_image_modulation=False, + attn_out_bias=False, + operation_settings={}, + ) -> None: + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + self.attention = JointAttention( + dim, + n_heads, + n_kv_heads, + qk_norm, + out_bias=attn_out_bias, + operation_settings=operation_settings, + ) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + operation_settings=operation_settings, + ) + self.layer_id = layer_id + self.attention_norm1 = operation_settings.get("operations").RMSNorm( + dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.ffn_norm1 = operation_settings.get("operations").RMSNorm( + dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.attention_norm2 = operation_settings.get("operations").RMSNorm( + dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.ffn_norm2 = operation_settings.get("operations").RMSNorm( + dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + transformer_options={}, + ): + if adaln_input is not None: + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + clamp_fp16( + self.attention( + modulate(self.attention_norm1(x), scale_msa), + x_mask, + freqs_cis, + transformer_options=transformer_options, + ) + ) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + clamp_fp16( + self.feed_forward( + modulate(self.ffn_norm1(x), scale_mlp), + ) + ) + ) + else: + x = x + self.attention_norm2( + clamp_fp16( + self.attention( + self.attention_norm1(x), + x_mask, + freqs_cis, + transformer_options=transformer_options, + ) + ) + ) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + return x + + +class FinalLayer(nn.Module): + """Final layer with LayerNorm and output projection.""" + + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + z_image_modulation=False, + operation_settings={}, + ): + super().__init__() + self.norm_final = operation_settings.get("operations").LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + self.linear = operation_settings.get("operations").Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + min_mod = 256 if z_image_modulation else 1024 + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(hidden_size, min_mod), + hidden_size, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor): + scale = self.adaLN_modulation(c) + x = modulate(self.norm_final(x), scale) + x = self.linear(x) + return x + + +class TwinFlowZImageTransformer(nn.Module): + """ + TwinFlow-Z-Image transformer model. + + This custom architecture handles dual timestep embeddings + (t_embedder and t_embedder_2), the primary TwinFlow distinction. + """ + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + dim: int = 3840, + n_layers: int = 30, + n_refiner_layers: int = 2, + n_heads: int = 30, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: float = 2.6666666666666665, + norm_eps: float = 1e-5, + qk_norm: bool = True, + cap_feat_dim: int = 2560, + axes_dims: List[int] = (32, 48, 48), + axes_lens: List[int] = (1, 1536, 512, 512), + rope_theta: float = 256.0, + z_image_modulation: bool = True, + time_scale: float = 1000.0, + pad_tokens_multiple=None, + clip_text_dim=None, + image_model=None, + device=None, + dtype=None, + operations=None, + **kwargs, + ) -> None: + super().__init__() + self.dtype = dtype + operation_settings = { + "operations": operations, + "device": device, + "dtype": dtype, + } + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple + + self.x_embedder = operation_settings.get("operations").Linear( + in_features=patch_size * patch_size * in_channels, + out_features=dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ) + + self.t_embedder = TimestepEmbedder( + min(dim, 1024), + output_size=256 if z_image_modulation else None, + **operation_settings, + ) + self.t_embedder_2 = TimestepEmbedder( + min(dim, 1024), + output_size=256 if z_image_modulation else None, + **operation_settings, + ) + + self.noise_refiner = nn.ModuleList( + [ + TwinFlowTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=z_image_modulation, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + TwinFlowTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.cap_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm( + cap_feat_dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + operation_settings.get("operations").Linear( + cap_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + self.clip_text_pooled_proj = None + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm( + clip_text_dim, + eps=norm_eps, + elementwise_affine=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + + self.layers = nn.ModuleList( + [ + TwinFlowTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_layers) + ] + ) + + self.final_layer = FinalLayer( + dim, + patch_size, + self.out_channels, + z_image_modulation=z_image_modulation, + operation_settings=operation_settings, + ) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + + assert (dim // n_heads) == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) + self.dim = dim + self.n_heads = n_heads + + def unpatchify( + self, + x: torch.Tensor, + img_size: List[Tuple[int, int]], + cap_size: List[int], + return_tensor=False, + ) -> List[torch.Tensor]: + pH = pW = self.patch_size + imgs = [] + for i in range(x.size(0)): + H, W = img_size[i] + begin = cap_size[i] + end = begin + (H // pH) * (W // pW) + imgs.append( + x[i][begin:end] + .view(H // pH, W // pW, pH, pW, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + + if return_tensor: + imgs = torch.stack(imgs, dim=0) + return imgs + + def patchify_and_embed( + self, + x: List[torch.Tensor] | torch.Tensor, + cap_feats: torch.Tensor, + cap_mask: torch.Tensor, + t: torch.Tensor, + num_tokens, + transformer_options={}, + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + + B, C, H, W = x.shape + x = self.x_embedder( + x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2) + ) + + rope_options = transformer_options.get("rope_options", {}) + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = ( + torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start + ).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = ( + torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start + ).view(1, -1).repeat(H_tokens, 1).flatten() + + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat( + ( + x, + self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1), + ), + dim=1, + ) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) + + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat( + ( + cap_feats, + self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True) + .unsqueeze(0) + .repeat(cap_feats.shape[0], pad_extra, 1), + ), + dim=1, + ) + cap_pos_ids = torch.nn.functional.pad(cap_pos_ids, (0, 0, 0, pad_extra), value=0) + + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + + for layer in self.context_refiner: + cap_feats = layer( + cap_feats, + cap_mask, + freqs_cis[:, : cap_pos_ids.shape[1]], + transformer_options=transformer_options, + ) + + padded_img_mask = None + for _, layer in enumerate(self.noise_refiner): + x = layer( + x, + padded_img_mask, + freqs_cis[:, cap_pos_ids.shape[1] :], + t, + transformer_options=transformer_options, + ) + + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz + + return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, + kwargs.get("transformer_options", {}), + ), + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + def _forward( + self, + x, + timesteps, + context, + num_tokens, + attention_mask=None, + transformer_options={}, + **kwargs, + ): + t = 1.0 - timesteps + + t_emb = self.t_embedder(t * self.time_scale, dtype=x.dtype) + t_emb_2 = self.t_embedder_2(t * self.time_scale, dtype=x.dtype) + + adaln_input = t_emb + t_emb_2 + + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + cap_feats = self.cap_embedder(cap_feats) + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = torch.cat((t_emb, pooled), dim=-1) + adaln_input = self.clip_text_pooled_proj(adaln_input) + + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, + cap_feats, + cap_mask, + adaln_input, + num_tokens, + transformer_options=transformer_options, + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + + img = self.final_layer(img, adaln_input) + img = self.unpatchify( + img, + img_size, + cap_size, + return_tensor=isinstance(x, torch.Tensor), + )[:, :, :h, :w] + + return -img