mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 21:37:40 +08:00
298 lines
14 KiB
Python
298 lines
14 KiB
Python
"""
|
|
The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model that
consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims) and
packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.patcher_extension
|
|
from comfy.ldm.lumina.model import FeedForward
|
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
|
from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis
|
|
|
|
# Per-token tags used while packing the sequence.
# Segment id given to padded text slots; they form their own attention segment.
SEQUENCE_PADDING_INDICATOR = -1
# Role tag of an image token that is being denoised.
OUTPUT_IMAGE_INDICATOR = 2
# Role tag of a real text (LLM feature) token.
LLM_TOKEN_INDICATOR = 3

# Shift added to every image-grid coordinate so image rotary positions never coincide
# with text positions (text positions are 0, 1, 2, ... up to the text length).
IMAGE_POSITION_OFFSET = 1 << 16  # == 65536
|
|
|
|
|
|
class Ideogram4Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection, per-head RMS QK-norm and rotary embeddings.

    Args:
        hidden_size: model width; split evenly into ``num_heads`` heads of ``hidden_size // num_heads``.
        num_heads: number of attention heads.
        eps: epsilon of the per-head query/key RMSNorms.
        dtype, device, operations: forwarded to the ``operations`` layer factories.
    """

    def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        head_dim = hidden_size // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.hidden_size = hidden_size

        # state_dict keys derive from these attribute names; keep them stable.
        self.qkv = operations.Linear(hidden_size, 3 * hidden_size, bias=False, dtype=dtype, device=device)
        self.norm_q = operations.RMSNorm(head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, x, attn_mask, freqs_cis, transformer_options={}):
        """Attend over the sequence axis of ``x`` (3-D: batch, sequence, hidden_size).

        ``attn_mask`` and ``freqs_cis`` are passed through unchanged to ``optimized_attention_masked``
        and ``apply_rope`` respectively.
        """
        bsz, n_tokens, _ = x.shape
        packed = self.qkv(x).view(bsz, n_tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.unbind(dim=2)

        # QK-norm per head, then move heads ahead of the sequence axis: (B, heads, L, head_dim)
        query = self.norm_q(query).transpose(1, 2)
        key = self.norm_k(key).transpose(1, 2)
        value = value.transpose(1, 2)

        query, key = apply_rope(query, key, freqs_cis)

        # skip_reshape=True because q/k/v are already laid out as (B, heads, L, head_dim)
        attended = optimized_attention_masked(
            query, key, value, self.num_heads, attn_mask,
            skip_reshape=True, transformer_options=transformer_options,
        )
        return self.o(attended)
|
|
|
|
|
|
class Ideogram4TransformerBlock(nn.Module):
    """DiT block with sandwich RMS norms and tanh-gated residuals, modulated by an AdaLN vector.

    Both sublayers (attention, feed-forward) follow the same pattern:
    ``x + tanh(gate) * post_norm(sublayer(pre_norm(x) * (1 + scale)))``.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.attention = Ideogram4Attention(hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
        self.feed_forward = FeedForward(
            dim=hidden_size,
            hidden_dim=intermediate_size,
            multiple_of=1,
            ffn_dim_multiplier=None,
            operation_settings={"operations": operations, "dtype": dtype, "device": device},
        )

        def rms_norm():
            return operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)

        # *_norm1 normalize the sublayer inputs, *_norm2 normalize the sublayer outputs before gating.
        self.attention_norm1 = rms_norm()
        self.ffn_norm1 = rms_norm()
        self.attention_norm2 = rms_norm()
        self.ffn_norm2 = rms_norm()

        # A single projection yields (scale_msa, gate_msa, scale_mlp, gate_mlp).
        self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device)

    def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}):
        """Apply the attention and feed-forward sublayers, each scaled and gated by ``adaln_input``."""
        modulation = self.adaln_modulation(adaln_input)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=-1)

        # attention sublayer
        attn_in = self.attention_norm1(x) * (1.0 + scale_msa)
        attn_out = self.attention(attn_in, attn_mask, freqs_cis, transformer_options=transformer_options)
        x = x + torch.tanh(gate_msa) * self.attention_norm2(attn_out)

        # feed-forward sublayer
        ffn_in = self.ffn_norm1(x) * (1.0 + scale_mlp)
        x = x + torch.tanh(gate_mlp) * self.ffn_norm2(self.feed_forward(ffn_in))
        return x
|
|
|
|
|
|
def _sinusoidal_embedding(t, dim, scale=1e4):
|
|
t = t.to(torch.float32)
|
|
half = dim // 2
|
|
freq = math.log(scale) / (half - 1)
|
|
freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq)
|
|
emb = t.unsqueeze(-1) * freq
|
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
|
if dim % 2 == 1:
|
|
emb = F.pad(emb, (0, 1))
|
|
return emb
|
|
|
|
|
|
class Ideogram4EmbedScalar(nn.Module):
    """Embed scalars (the timestep, in ``Ideogram4Transformer``) with sinusoidal features and a 2-layer MLP.

    Inputs are mapped linearly from ``input_range`` onto ``[0, 1e4]`` before the sinusoidal embedding,
    then passed through ``Linear -> SiLU -> Linear``.
    """

    def __init__(self, dim, input_range=(0.0, 1.0), dtype=None, device=None, operations=None):
        super().__init__()
        low, high = input_range
        self.dim = dim
        self.range_min = low
        self.range_max = high
        self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
        self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        """Return the MLP output for ``x``; the trailing axis of the result has ``dim`` channels."""
        x32 = x.to(torch.float32)
        span = self.range_max - self.range_min
        features = _sinusoidal_embedding(1e4 * (x32 - self.range_min) / span, self.dim)
        # The MLP is fed in the dtype of its own weights.
        # NOTE(review): if weights are kept in a low-precision storage dtype this cast may be lossy — confirm.
        hidden = F.silu(self.mlp_in(features.to(self.mlp_in.weight.dtype)))
        return self.mlp_out(hidden)
|
|
|
|
|
|
class Ideogram4FinalLayer(nn.Module):
    """Output head: affine-free LayerNorm, AdaLN scale derived from ``c``, then a linear projection.

    Computes ``linear(layer_norm(x) * (1 + adaln_modulation(silu(c))))``.
    """

    def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
        self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, dtype=dtype, device=device)

    def forward(self, x, c):
        """Project ``x`` to ``out_channels`` after normalizing it and scaling it by the conditioning ``c``."""
        modulation = self.adaln_modulation(F.silu(c))
        normed = self.norm_final(x)
        return self.linear(normed * (1.0 + modulation))
|
|
|
|
|
|
class Ideogram4Transformer(nn.Module):
    """A single Ideogram 4 backbone operating on a packed token sequence.

    Args:
        emb_dim: model width; split evenly across ``num_heads`` attention heads.
        num_layers: number of ``Ideogram4TransformerBlock`` layers.
        num_heads: attention heads per block.
        intermediate_size: feed-forward hidden width.
        adaln_dim: width of the AdaLN conditioning vector shared by all blocks and the head.
        in_channels: channels of each latent token; also the output width.
        llm_features_dim: width of the incoming LLM (text) feature vectors.
        rope_theta, mrope_section: rotary base and per-axis split, forwarded to ``precompute_freqs_cis``.
        norm_eps: epsilon of the blocks' RMSNorms.
        dtype, device, operations: forwarded to the ``operations`` layer factories.
    """

    def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim,
                 in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.head_dim = emb_dim // num_heads
        self.rope_theta = rope_theta
        self.mrope_section = tuple(mrope_section)

        factory = {"dtype": dtype, "device": device}
        self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, **factory)
        self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, **factory)
        self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, **factory)
        self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), operations=operations, **factory)
        self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, **factory)

        # Row 1 is added to output-image tokens, row 0 to every other token.
        self.embed_image_indicator = operations.Embedding(2, emb_dim, **factory)

        self.layers = nn.ModuleList(
            Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim,
                                      operations=operations, **factory)
            for _ in range(num_layers)
        )

        self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, operations=operations, **factory)

    def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}):
        """Run the packed sequence through every block and the output head.

        Args:
            llm_features: features for the first ``llm_features.shape[1]`` positions, or None.
            x: latent tokens ``(B, L, in_channels)``; only rows tagged OUTPUT_IMAGE_INDICATOR are used.
            t: timestep(s) for ``t_embedding``; a 1-D ``t`` gets a length-1 sequence axis.
            position_ids: ``(B, L, 3)`` rotary coordinates; only batch entry 0 is used.
            attn_mask: None, a bool mask (True = attend) converted to an additive mask, or a
                non-bool mask passed through unchanged.
            indicator: ``(B, L)`` per-token role tags.
            transformer_options: forwarded to every block.

        Returns:
            ``final_layer`` output for every position.
        """
        roles = indicator.to(torch.long)
        is_image = roles == OUTPUT_IMAGE_INDICATOR
        image_rows = is_image.to(x.dtype).unsqueeze(-1)

        # Non-image rows are zeroed before and after the projection (this also removes the bias there).
        h = self.input_proj(x * image_rows) * image_rows

        t_cond = self.t_embedding(t)
        if t.dim() == 1:
            t_cond = t_cond.unsqueeze(1)
        adaln_input = F.silu(self.adaln_proj(t_cond))

        # h is zero on the text rows, so adding in place writes the text features there;
        # slice assignment also keeps h in its own dtype.
        if llm_features is not None:
            n_text = llm_features.shape[1]
            text_rows = (roles[:, :n_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1)
            text_emb = self.llm_cond_proj(self.llm_cond_norm(llm_features * text_rows)) * text_rows
            h[:, :n_text] = h[:, :n_text] + text_emb

        h = h + self.embed_image_indicator(is_image.to(torch.long))

        # Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L), assumed identical across the batch.
        freqs_cis = precompute_freqs_cis(
            self.head_dim,
            position_ids[0].transpose(0, 1),
            self.rope_theta,
            rope_dims=self.mrope_section,
            interleaved_mrope=True,
            device=position_ids.device,
        )

        if attn_mask is not None and attn_mask.dtype == torch.bool:
            blocked = -torch.finfo(h.dtype).max
            attn_mask = torch.zeros_like(attn_mask, dtype=h.dtype).masked_fill_(~attn_mask, blocked)

        for block in self.layers:
            h = block(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options)

        # NOTE(review): adaln_input is already SiLU-activated, and Ideogram4FinalLayer applies F.silu to its
        # conditioning again, so the head sees silu(silu(.)) while the blocks see silu(.). Confirm this
        # matches the reference implementation.
        return self.final_layer(h, adaln_input)
|
|
|
|
|
|
class Ideogram4Transformer2DModel(Ideogram4Transformer):
    """Ideogram 4 single-stream DiT.

    Runs a packed ``[text, image]`` sequence when text context is supplied, or an image-only sequence when ``context is None``.
    Latents arrive as ``(B, in_channels, gh, gw)``. Each spatial location is one token whose channels pack a
    ``patch_size x patch_size`` patch of ``ae_channels`` autoencoder channels.
    """

    def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288,
                 adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5,
                 dtype=None, device=None, operations=None, **kwargs):
        emb_dim = num_attention_heads * attention_head_dim
        super().__init__(
            emb_dim=emb_dim, num_layers=num_layers, num_heads=num_attention_heads,
            intermediate_size=intermediate_size, adaln_dim=adaln_dim, in_channels=in_channels,
            llm_features_dim=llm_features_dim, rope_theta=rope_theta, mrope_section=mrope_section,
            norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
        self.dtype = dtype
        self.in_channels = in_channels
        self.out_channels = in_channels
        # 128-dim token = patch (2x2) * ae_channels (32).
        self.patch_size = 2
        self.ae_channels = in_channels // (self.patch_size * self.patch_size)

    def _img_to_tokens(self, x):
        """Convert a ``(B, C, gh, gw)`` latent into ``(B, gh * gw, C)`` tokens.

        Channels are unpacked as ``(ae_channels, pi, pj)`` and re-packed per token in ``(pi, pj, ae_channel)`` order.
        """
        B, C, gh, gw = x.shape
        # reshape instead of view: view raises on a non-contiguous latent, reshape only copies when needed.
        x = x.reshape(B, self.ae_channels, self.patch_size, self.patch_size, gh, gw)
        x = x.permute(0, 4, 5, 2, 3, 1)  # (B, gh, gw, pi, pj, c)
        return x.reshape(B, gh * gw, C)

    def _tokens_to_img(self, tokens, gh, gw):
        """Inverse of ``_img_to_tokens``: ``(B, gh * gw, C)`` tokens -> ``(B, C, gh, gw)`` latent."""
        B = tokens.shape[0]
        C = tokens.shape[-1]
        x = tokens.reshape(B, gh, gw, self.patch_size, self.patch_size, self.ae_channels)
        x = x.permute(0, 5, 3, 4, 1, 2)  # (B, c, pi, pj, gh, gw)
        return x.reshape(B, C, gh, gw)

    def _image_position_ids(self, gh, gw, device):
        """Return ``(gh * gw, 3)`` rotary coordinates ``(0, row, col)``, row-major, shifted by IMAGE_POSITION_OFFSET."""
        h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
        w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
        t_idx = torch.zeros_like(h_idx)
        return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET  # (L_img, 3)

    def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
        """Denoise ``x_chunk`` conditioned on ``context_chunk`` using a packed ``[text, image]`` sequence.

        ``attn_mask_chunk`` (optional) marks padded text slots with 0. Those slots are moved to their own
        attention segment and lose their LLM role. Real text and image tokens share one segment.
        Returns the image part of the output as a ``(B, C, gh, gw)`` latent.
        """
        B = x_chunk.shape[0]
        device = x_chunk.device
        img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
        L_img = img_tokens.shape[1]
        L_text = context_chunk.shape[1]
        L = L_text + L_img

        # Text rows stay zero in the latent stream; text content enters via llm_features in the backbone.
        x_full = torch.zeros(B, L, img_tokens.shape[-1], dtype=img_tokens.dtype, device=device)
        x_full[:, L_text:] = img_tokens

        # Text tokens use the same 1-D index on all three rotary axes.
        text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
        img_pos = self._image_position_ids(gh, gw, device)
        position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)

        indicator = torch.empty(B, L, dtype=torch.long, device=device)
        indicator[:, :L_text] = LLM_TOKEN_INDICATOR
        indicator[:, L_text:] = OUTPUT_IMAGE_INDICATOR

        attn_mask = None
        if attn_mask_chunk is not None:
            segment_ids = torch.ones(B, L, dtype=torch.long, device=device)
            pad = (attn_mask_chunk == 0)
            # Basic slicing yields views, so these masked assignments write through.
            segment_ids[:, :L_text][pad] = SEQUENCE_PADDING_INDICATOR
            indicator[:, :L_text][pad] = 0
            # Block-diagonal mask from segment ids: (B, 1, L, L), True = attend.
            attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)

        out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
                             transformer_options=transformer_options)
        return self._tokens_to_img(out[:, L_text:], gh, gw)

    def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
        """Run the backbone on image tokens alone (no LLM context) and return a ``(B, C, gh, gw)`` latent."""
        B = x_chunk.shape[0]
        device = x_chunk.device
        img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
        L_img = img_tokens.shape[1]

        position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
        indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)

        # Image-only sequence is a single segment -> no mask, full attention, no LLM context.
        out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
        return self._tokens_to_img(out, gh, gw)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Entry point; dispatches ``_forward`` through any registered DIFFUSION_MODEL wrappers."""
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Predict for latent ``x`` (4-D); uses the image-only path when ``context`` is None."""
        _, _, gh, gw = x.shape

        # The backbone receives 1 - timesteps and its output is negated on return;
        # presumably this converts between the sampler's and the model's time/velocity conventions — confirm.
        timesteps = 1.0 - timesteps

        # unconditional pass
        if context is None:
            return -self._run_image_only(x, timesteps, gh, gw, transformer_options)

        return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)
|