ComfyUI/comfy/ldm/ideogram4/model.py
2026-06-03 08:41:44 -07:00

298 lines
14 KiB
Python

"""
The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model. It
consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims) and
packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.lumina.model import FeedForward
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis
# Per-token role indicators.
# Segment id written over padded text positions when the block-diagonal attention mask is built.
SEQUENCE_PADDING_INDICATOR = -1
# Marks tokens that carry the image latents (the rows kept by the image mask in the backbone).
OUTPUT_IMAGE_INDICATOR = 2
# Marks tokens that carry LLM text features (the rows kept by the text mask in the backbone).
LLM_TOKEN_INDICATOR = 3
# Image grid coordinates are offset so they never collide with text positions.
# The offset is added to all three position axes of every image token; text positions start at 0.
# NOTE(review): this assumes text sequences are always shorter than the offset -- confirm.
IMAGE_POSITION_OFFSET = 65536
class Ideogram4Attention(nn.Module):
def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.hidden_size = hidden_size
self.qkv = operations.Linear(hidden_size, hidden_size * 3, bias=False, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, x, attn_mask, freqs_cis, transformer_options={}):
batch_size, seq_len, _ = x.shape
qkv = self.qkv(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = self.norm_q(q)
k = self.norm_k(k)
# (B, heads, L, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q, k = apply_rope(q, k, freqs_cis)
out = optimized_attention_masked(q, k, v, self.num_heads, attn_mask, skip_reshape=True, transformer_options=transformer_options)
return self.o(out)
class Ideogram4TransformerBlock(nn.Module):
    """One transformer layer: attention and feed-forward, each scaled and gated by an adaLN signal.

    Each sub-layer is normalised before it runs (``*_norm1``) and again on its output
    (``*_norm2``) before the gated residual add.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None):
        super().__init__()

        def make_norm():
            # All four block-level norms share the same configuration.
            return operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)

        # The attention's q/k norms use a fixed eps of 1e-5 rather than ``norm_eps``.
        self.attention = Ideogram4Attention(
            hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations,
        )
        self.feed_forward = FeedForward(
            dim=hidden_size,
            hidden_dim=intermediate_size,
            multiple_of=1,
            ffn_dim_multiplier=None,
            operation_settings={"operations": operations, "dtype": dtype, "device": device},
        )
        self.attention_norm1 = make_norm()
        self.ffn_norm1 = make_norm()
        self.attention_norm2 = make_norm()
        self.ffn_norm2 = make_norm()
        # Produces four hidden_size-wide chunks: attention scale/gate and feed-forward scale/gate.
        self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device)

    def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}):
        """Apply the modulated attention and feed-forward sub-layers to ``x`` and return the result."""
        attn_scale, attn_gate, ffn_scale, ffn_gate = self.adaln_modulation(adaln_input).chunk(4, dim=-1)

        # Gates are squashed with tanh; scales are applied as (1 + scale).
        attn_gate = torch.tanh(attn_gate)
        ffn_gate = torch.tanh(ffn_gate)
        attn_scale = 1.0 + attn_scale
        ffn_scale = 1.0 + ffn_scale

        attn_in = self.attention_norm1(x) * attn_scale
        attn_out = self.attention(attn_in, attn_mask, freqs_cis, transformer_options=transformer_options)
        x = x + attn_gate * self.attention_norm2(attn_out)

        ffn_in = self.ffn_norm1(x) * ffn_scale
        ffn_out = self.feed_forward(ffn_in)
        x = x + ffn_gate * self.ffn_norm2(ffn_out)
        return x
def _sinusoidal_embedding(t, dim, scale=1e4):
    """Return a sinusoidal embedding of ``t`` with ``dim`` features on a new last axis.

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are cosines of
    ``t * freq``, where the frequencies fall geometrically from ``1`` down to
    ``1 / scale``. When ``dim`` is odd, a single trailing zero feature is appended.

    Args:
        t: Tensor of scalar values (any shape); it is converted to float32.
        dim: Size of the embedding axis.
        scale: Ratio between the highest and the lowest frequency.

    Returns:
        A float32 tensor of shape ``t.shape + (dim,)``.
    """
    t = t.to(torch.float32)
    half = dim // 2
    # The geometric spacing divides by ``half - 1``. Clamp the divisor so that dim 2 and 3
    # (half == 1) use the single frequency 1.0 instead of raising ZeroDivisionError.
    # Every other dim produces exactly the same values as before.
    log_step = math.log(scale) / max(half - 1, 1)
    freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -log_step)
    emb = t.unsqueeze(-1) * freq
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if dim % 2 == 1:
        # Odd widths get one zero column so the output is exactly ``dim`` wide.
        emb = F.pad(emb, (0, 1))
    return emb
class Ideogram4EmbedScalar(nn.Module):
    """Embed a scalar (such as a timestep) with sinusoidal features followed by a two-layer MLP."""

    def __init__(self, dim, input_range=(0.0, 1.0), dtype=None, device=None, operations=None):
        super().__init__()
        lower, upper = input_range
        self.dim = dim
        self.range_min = lower
        self.range_max = upper
        self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
        self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        """Map ``x`` from ``[range_min, range_max]`` onto ``[0, 1e4]``, embed it and run the MLP."""
        span = self.range_max - self.range_min
        # The rescaling and the sinusoidal features are computed in float32.
        rescaled = 1e4 * (x.to(torch.float32) - self.range_min) / span
        features = _sinusoidal_embedding(rescaled, self.dim)
        # Cast to the MLP's weight dtype before the linear layers.
        features = features.to(self.mlp_in.weight.dtype)
        hidden = F.silu(self.mlp_in(features))
        return self.mlp_out(hidden)
class Ideogram4FinalLayer(nn.Module):
    """Output head: parameter-free LayerNorm, adaLN scale, then a linear projection to ``out_channels``."""

    def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        # The norm has no learnable affine; the scale comes from the conditioning instead.
        self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, **factory)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, **factory)
        self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, **factory)

    def forward(self, x, c):
        """Normalise ``x``, scale it by ``1 + adaln_modulation(silu(c))`` and project it."""
        modulation = self.adaln_modulation(F.silu(c))
        normed = self.norm_final(x)
        return self.linear(normed * (1.0 + modulation))
class Ideogram4Transformer(nn.Module):
    """A single Ideogram 4 backbone operating on a packed token sequence."""

    def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim,
                 in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.head_dim = emb_dim // num_heads
        self.rope_theta = rope_theta
        self.mrope_section = tuple(mrope_section)

        # Token embedders: image latents and LLM features are both projected into emb_dim.
        self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, **factory)
        self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, **factory)
        self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, **factory)

        # Timestep conditioning, reduced to adaln_dim for the per-layer modulation.
        self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), operations=operations, **factory)
        self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, **factory)

        # Two learned vectors: index 1 for image tokens, index 0 for everything else.
        self.embed_image_indicator = operations.Embedding(2, emb_dim, **factory)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim,
                                          dtype=dtype, device=device, operations=operations)
            )
        self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, operations=operations, **factory)

    def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}):
        """Embed the packed sequence, run every layer and return the final-layer output.

        ``x`` holds the latent tokens, ``indicator`` gives each token's role, and
        ``llm_features`` (optional) supplies text features for the leading positions.
        """
        indicator = indicator.to(torch.long)
        is_image = indicator == OUTPUT_IMAGE_INDICATOR
        image_gate = is_image.to(x.dtype).unsqueeze(-1)

        # Non-image rows are zeroed both before and after the projection (the latter drops the bias).
        x = x * image_gate
        h = self.input_proj(x) * image_gate

        t_cond = self.t_embedding(t)
        if t.dim() == 1:
            # Insert a token axis so the conditioning broadcasts over the sequence.
            t_cond = t_cond.unsqueeze(1)
        adaln_input = F.silu(self.adaln_proj(t_cond))

        # h is zero on the text rows (content lives only on image rows), so the add
        # below effectively writes the text features into place.
        if llm_features is not None:
            n_text = llm_features.shape[1]
            text_gate = (indicator[:, :n_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1)
            text_h = self.llm_cond_norm(llm_features * text_gate)
            text_h = self.llm_cond_proj(text_h) * text_gate
            h[:, :n_text] = h[:, :n_text] + text_h

        h = h + self.embed_image_indicator(is_image.to(torch.long))

        # Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
        axis_positions = position_ids[0].transpose(0, 1)
        freqs_cis = precompute_freqs_cis(
            self.head_dim, axis_positions, self.rope_theta,
            rope_dims=self.mrope_section, interleaved_mrope=True, device=position_ids.device,
        )

        # A boolean mask (True = attend) becomes an additive one: 0 where allowed,
        # the most negative finite value of h's dtype where blocked.
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            blocked = ~attn_mask
            additive = torch.zeros_like(attn_mask, dtype=h.dtype)
            attn_mask = additive.masked_fill_(blocked, -torch.finfo(h.dtype).max)

        for block in self.layers:
            h = block(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options)
        return self.final_layer(h, adaln_input)
class Ideogram4Transformer2DModel(Ideogram4Transformer):
    """Ideogram 4 single-stream DiT.

    Runs a packed ``[text, image]`` sequence when text context is supplied, or an
    image-only sequence when ``context is None``.
    """

    def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288,
                 adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5,
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__(
            emb_dim=num_attention_heads * attention_head_dim,
            num_layers=num_layers,
            num_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            adaln_dim=adaln_dim,
            in_channels=in_channels,
            llm_features_dim=llm_features_dim,
            rope_theta=rope_theta,
            mrope_section=mrope_section,
            norm_eps=norm_eps,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.dtype = dtype
        self.in_channels = in_channels
        self.out_channels = in_channels
        # 128-dim token = patch (2x2) * ae_channels (32).
        self.patch_size = 2
        self.ae_channels = in_channels // (self.patch_size * self.patch_size)

    def _img_to_tokens(self, x):
        """Rearrange a (B, C, gh, gw) latent into (B, gh * gw, C) tokens."""
        B, C, gh, gw = x.shape
        p = self.patch_size
        # Split C into (c, pi, pj), then move the grid axes forward: (B, gh, gw, pi, pj, c).
        split = x.view(B, self.ae_channels, p, p, gh, gw).permute(0, 4, 5, 2, 3, 1)
        return split.reshape(B, gh * gw, C)

    def _tokens_to_img(self, tokens, gh, gw):
        """Inverse of ``_img_to_tokens``: (B, gh * gw, C) tokens back to (B, C, gh, gw)."""
        B, C = tokens.shape[0], tokens.shape[-1]
        p = self.patch_size
        # (B, gh, gw, pi, pj, c) -> (B, c, pi, pj, gh, gw)
        grid = tokens.reshape(B, gh, gw, p, p, self.ae_channels).permute(0, 5, 3, 4, 1, 2)
        return grid.reshape(B, C, gh, gw)

    def _image_position_ids(self, gh, gw, device):
        """Return (gh * gw, 3) position ids ``[0, row, col] + IMAGE_POSITION_OFFSET`` in row-major order."""
        rows = torch.arange(gh, device=device)[:, None].expand(gh, gw).flatten()
        cols = torch.arange(gw, device=device)[None, :].expand(gh, gw).flatten()
        frames = torch.zeros_like(rows)
        # The offset is applied to all three axes.
        return torch.stack([frames, rows, cols], dim=1) + IMAGE_POSITION_OFFSET  # (L_img, 3)

    def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
        """Run the backbone on a packed ``[text, image]`` sequence and return the image part as a latent."""
        device = x_chunk.device
        batch = x_chunk.shape[0]

        img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
        n_img, latent_dim = img_tokens.shape[1], img_tokens.shape[-1]
        n_text = context_chunk.shape[1]
        total = n_text + n_img

        # Text positions carry zeros in the latent input; their content comes from the LLM features.
        text_slots = torch.zeros(batch, n_text, latent_dim, dtype=img_tokens.dtype, device=device)
        x_full = torch.cat([text_slots, img_tokens], dim=1)

        # Text tokens use the same index on all three axes; image tokens use offset grid coordinates.
        text_pos = torch.arange(n_text, device=device).unsqueeze(-1).expand(n_text, 3)
        img_pos = self._image_position_ids(gh, gw, device)
        position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(batch, total, 3)

        indicator = torch.cat(
            [
                torch.full((batch, n_text), LLM_TOKEN_INDICATOR, dtype=torch.long, device=device),
                torch.full((batch, n_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device),
            ],
            dim=1,
        )

        attn_mask = None
        if attn_mask_chunk is not None:
            pad = attn_mask_chunk == 0
            segment_ids = torch.ones(batch, total, dtype=torch.long, device=device)
            # Padded text goes into its own segment and loses its text role.
            text_segments = segment_ids[:, :n_text]
            text_segments[pad] = SEQUENCE_PADDING_INDICATOR
            text_roles = indicator[:, :n_text]
            text_roles[pad] = 0
            # Block-diagonal mask from segment ids: (B, 1, L, L), True = attend.
            same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
            attn_mask = same_segment[:, None]

        out = self._backbone(
            context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
            transformer_options=transformer_options,
        )
        return self._tokens_to_img(out[:, n_text:], gh, gw)

    def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
        """Run the backbone on image tokens alone (no text context, no attention mask)."""
        device = x_chunk.device
        batch = x_chunk.shape[0]

        img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
        n_img = img_tokens.shape[1]

        grid_pos = self._image_position_ids(gh, gw, device)
        position_ids = grid_pos.unsqueeze(0).expand(batch, n_img, 3)
        indicator = torch.full((batch, n_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)

        # Image-only sequence is a single segment -> no mask, full attention, no LLM context.
        out = self._backbone(
            None, img_tokens, t_chunk, position_ids, None, indicator,
            transformer_options=transformer_options,
        )
        return self._tokens_to_img(out, gh, gw)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Run ``_forward`` through any DIFFUSION_MODEL wrappers found in ``transformer_options``."""
        wrappers = comfy.patcher_extension.get_all_wrappers(
            comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options,
        )
        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(self._forward, self, wrappers)
        return executor.execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Flip the timestep to ``1 - t``, run the matching branch and negate its output."""
        # Only the grid size is needed; the unpacking still requires a 4-D input.
        _, _, gh, gw = x.shape
        timesteps = 1.0 - timesteps

        # unconditional pass
        if context is None:
            return -self._run_image_only(x, timesteps, gh, gw, transformer_options)

        return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)