ComfyUI/comfy/ldm/pixeldit/model.py
2026-05-25 19:51:05 +03:00

273 lines
12 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from .modules import (
FinalLayer,
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln,
precompute_freqs_cis_2d,
)
class MMDiTJointAttention(nn.Module):
    """Joint MMDiT attention over an image stream and a text stream.

    Each stream has its own fused Q/K/V projection, per-head RMSNorm on the
    queries and keys, and its own output projection. Rotary embeddings are
    applied to each stream separately *before* concatenation, so image tokens
    use ``pos_img`` and text tokens use ``pos_txt`` (text RoPE is skipped when
    ``pos_txt`` is None). The joint sequence is ordered ``[text, image]``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        def make_qk_norm():
            # Per-head RMSNorm applied to queries / keys.
            return operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)

        # Fused Q/K/V projections: image stream (x) and text stream (y).
        self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.q_norm_x = make_qk_norm()
        self.k_norm_x = make_qk_norm()
        self.q_norm_y = make_qk_norm()
        self.k_norm_y = make_qk_norm()
        # Output projections back to model width, one per stream.
        self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)

    def _split_heads(self, tokens, batch, qkv_layer, q_norm, k_norm):
        """Project ``tokens`` [batch, N, dim] into (q, k, v), each [batch, heads, N, head_dim].

        Queries and keys are passed through their RMSNorm; values are returned raw.
        """
        length = tokens.shape[1]
        packed = qkv_layer(tokens).reshape(batch, length, 3, self.num_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        return q_norm(query), k_norm(key), value

    def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        """Attend jointly over ``[y, x]`` and return ``(x_out, y_out)``.

        Args:
            x: image tokens, [B, Nx, dim].
            y: text tokens, [B, Ny, dim].
            pos_img: rotary table for the image tokens (indexed with ``[None, None]``
                before ``apply_rope``).
            pos_txt: optional rotary table for the text tokens.
            attn_mask: forwarded unchanged to ``optimized_attention``.
            transformer_options: forwarded unchanged to ``optimized_attention``.

        Returns:
            Tuple of projected image tokens [B, Nx, dim] and text tokens [B, Ny, dim].
        """
        batch, n_img, _ = x.shape
        _, n_txt, _ = y.shape
        heads = self.num_heads
        width = heads * self.head_dim

        q_img, k_img, v_img = self._split_heads(x, batch, self.qkv_x, self.q_norm_x, self.k_norm_x)
        q_txt, k_txt, v_txt = self._split_heads(y, batch, self.qkv_y, self.q_norm_y, self.k_norm_y)

        # RoPE per stream, before concatenation, so each keeps its own positions.
        q_img, k_img = apply_rope(q_img, k_img, pos_img[None, None])
        if pos_txt is not None:
            q_txt, k_txt = apply_rope(q_txt, k_txt, pos_txt[None, None])

        # Text first, then image, along the sequence axis (dim 2 of [B, H, N, D]).
        joint = optimized_attention(
            torch.cat([q_txt, q_img], dim=2),
            torch.cat([k_txt, k_img], dim=2),
            torch.cat([v_txt, v_img], dim=2),
            heads,
            mask=attn_mask,
            skip_reshape=True,
            skip_output_reshape=True,
            transformer_options=transformer_options,
        )
        txt_part, img_part = joint.split([n_txt, n_img], dim=2)
        merged_txt = txt_part.transpose(1, 2).reshape(batch, n_txt, width)
        merged_img = img_part.transpose(1, 2).reshape(batch, n_img, width)
        return self.proj_x(merged_img), self.proj_y(merged_txt)
class MMDiTBlockT2I(nn.Module):
    """Dual-stream MMDiT block: joint attention plus per-stream SwiGLU MLPs.

    Image tokens ``x`` and text tokens ``y`` share a single
    ``MMDiTJointAttention`` call. Each stream has its own RMSNorm pre-norms,
    its own MLP, and its own adaLN projection. Both adaLN projections read the
    same conditioning tensor ``c`` and emit six chunks: (shift, scale, gate)
    for the attention sub-layer, then (shift, scale, gate) for the MLP sub-layer.
    """

    def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.groups = groups
        self.head_dim = hidden_size // groups

        factory = dict(dtype=dtype, device=device)
        ffn_dim = int(hidden_size * mlp_ratio)

        self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.attn = MMDiTJointAttention(
            hidden_size, num_heads=groups, qkv_bias=False, operations=operations, **factory
        )
        self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.mlp_x = FeedForwardSwiGLU(hidden_size, ffn_dim, multiple_of=1, operations=operations, **factory)
        self.mlp_y = FeedForwardSwiGLU(hidden_size, ffn_dim, multiple_of=1, operations=operations, **factory)
        # Kept inside nn.Sequential so checkpoint keys stay "adaLN_modulation_*.0.*".
        self.adaLN_modulation_img = nn.Sequential(
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, **factory),
        )
        self.adaLN_modulation_txt = nn.Sequential(
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, **factory),
        )

    def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        """Run one block and return the updated ``(x, y)``.

        ``c`` is the conditioning tensor fed to both adaLN projections; it must
        broadcast against the token tensors (presumably [B, 1, hidden] — confirm
        with the caller). Residual updates are ``t + gate * sublayer(...)``.
        """
        (img_attn_shift, img_attn_scale, img_attn_gate,
         img_mlp_shift, img_mlp_scale, img_mlp_gate) = self.adaLN_modulation_img(c).chunk(6, dim=-1)
        (txt_attn_shift, txt_attn_scale, txt_attn_gate,
         txt_mlp_shift, txt_mlp_scale, txt_mlp_gate) = self.adaLN_modulation_txt(c).chunk(6, dim=-1)

        img_in = apply_adaln(self.norm_x1(x), img_attn_shift, img_attn_scale)
        txt_in = apply_adaln(self.norm_y1(y), txt_attn_shift, txt_attn_scale)
        img_delta, txt_delta = self.attn(
            img_in, txt_in, pos_img, pos_txt, attn_mask, transformer_options=transformer_options
        )
        x = torch.addcmul(x, img_attn_gate, img_delta)
        y = torch.addcmul(y, txt_attn_gate, txt_delta)

        img_ffn = self.mlp_x(apply_adaln(self.norm_x2(x), img_mlp_shift, img_mlp_scale))
        x = torch.addcmul(x, img_mlp_gate, img_ffn)
        txt_ffn = self.mlp_y(apply_adaln(self.norm_y2(y), txt_mlp_shift, txt_mlp_scale))
        y = torch.addcmul(y, txt_mlp_gate, txt_ffn)
        return x, y
class PixDiT_T2I(nn.Module):
    """PixelDiT text-to-image model operating directly in pixel space (no VAE).

    Defaults are hardcoded for the released 1024px Stage-3 checkpoint (it also
    runs at 512px when fed the appropriate latent size and flow_shift).

    Pipeline:
      1. Patchify the image (``patch_size`` x ``patch_size``) and run
         ``patch_depth`` dual-stream MMDiT blocks jointly with the text tokens.
      2. Turn the patch tokens into per-patch conditioning
         (``silu(t_emb + s)``).
      3. Refine per-pixel tokens with ``pixel_depth`` PiT blocks, then project
         to ``out_channels`` and fold back into an image.

    Forward:
        x:         [B, 3, H, W] pixel-space input. H/W need not be multiples of
                   ``patch_size``; the input is padded internally and the output
                   is cropped back to H x W.
        timesteps: [B] in [0, 1000] (ComfyUI flow sampling convention)
        context:   [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt
                   prepended); tokens beyond ``txt_max_length`` are dropped.
    Returns flow-matching velocity [B, 3, H, W].

    ``image_model`` is accepted for constructor compatibility but is not used here.
    """

    def __init__(
        self,
        in_channels=3,
        num_groups=24,
        hidden_size=1536,
        pixel_hidden_size=16,
        pixel_attn_hidden_size=1152,
        pixel_num_groups=16,
        patch_depth=14,
        pixel_depth=2,
        patch_size=16,
        txt_embed_dim=2304,
        txt_max_length=300,
        use_text_rope=True,
        text_rope_theta=10000.0,
        use_pixel_abs_pos=True,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
        pixel_mlp_chunks=2,
    ):
        super().__init__()
        self.dtype = dtype
        self.in_channels = int(in_channels)
        self.out_channels = int(in_channels)
        self.hidden_size = int(hidden_size)
        self.num_groups = int(num_groups)
        self.patch_depth = int(patch_depth)
        self.pixel_depth = int(pixel_depth)
        self.patch_size = int(patch_size)
        self.pixel_hidden_size = int(pixel_hidden_size)
        self.pixel_attn_hidden_size = int(pixel_attn_hidden_size)
        self.pixel_num_groups = int(pixel_num_groups)
        self.txt_embed_dim = int(txt_embed_dim)
        self.txt_max_length = int(txt_max_length)
        self.use_text_rope = bool(use_text_rope)
        self.text_rope_theta = float(text_rope_theta)
        self.use_pixel_abs_pos = bool(use_pixel_abs_pos)

        # Per-pixel token embedder (feeds the PiT stage).
        self.pixel_embedder = PixelTokenEmbedder(
            self.in_channels, self.pixel_hidden_size, use_pixel_abs_pos=self.use_pixel_abs_pos,
            dtype=dtype, device=device, operations=operations,
        )
        # Patch embedder: consumes unfolded patches of in_channels * patch_size**2 values.
        self.s_embedder = PatchTokenEmbedder(
            self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True,
            dtype=dtype, device=device, operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10
        )
        # Text embedder plus learned absolute positions for up to txt_max_length tokens.
        self.y_embedder = PatchTokenEmbedder(
            self.txt_embed_dim, self.hidden_size, bias=True, norm_layer=True,
            dtype=dtype, device=device, operations=operations,
        )
        # Uninitialized storage; expected to be filled from the checkpoint.
        self.y_pos_embedding = nn.Parameter(
            torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)
        )
        self.patch_blocks = nn.ModuleList([
            MMDiTBlockT2I(self.hidden_size, self.num_groups,
                          dtype=dtype, device=device, operations=operations)
            for _ in range(self.patch_depth)
        ])
        self.pixel_blocks = nn.ModuleList([
            PiTBlock(
                self.pixel_hidden_size,
                self.hidden_size,
                patch_size=self.patch_size,
                num_heads=self.num_groups,
                mlp_ratio=4.0,
                attn_hidden_size=self.pixel_attn_hidden_size,
                attn_num_heads=self.pixel_num_groups,
                dtype=dtype, device=device, operations=operations,
                mlp_chunks=pixel_mlp_chunks,
            )
            for _ in range(self.pixel_depth)
        ])
        self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels,
                                      dtype=dtype, device=device, operations=operations)
        # Rotary tables keyed by grid size / text length; built once, cast per call.
        self._patch_pos_cache = {}
        self._text_pos_cache = {}

    def _fetch_patch_pos(self, height, width, device, dtype):
        """Return the 2D rotary table for a ``height`` x ``width`` patch grid.

        The table is built once per grid size with ``precompute_freqs_cis_2d``
        and cached; every call returns a copy cast to ``device``/``dtype``.
        """
        key = (height, width)
        pos = self._patch_pos_cache.get(key)
        if pos is None:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
            self._patch_pos_cache[key] = pos
        return pos.to(device=device, dtype=dtype)

    def _fetch_text_pos(self, length, device, dtype):
        """Return the 1D rotary table for ``length`` text tokens (cached per length)."""
        pos = self._text_pos_cache.get(length)
        if pos is None:
            positions = torch.arange(length, dtype=torch.float32).reshape(1, -1)
            pos = rope(positions, self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0)
            self._text_pos_cache[length] = pos
        return pos.to(device=device, dtype=dtype)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Entry point; routes through any registered DIFFUSION_MODEL wrappers to ``_forward``."""
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Compute the velocity prediction for ``x``. See the class docstring for shapes.

        ``attention_mask`` and extra ``kwargs`` are accepted but not used: both
        the patch and pixel blocks are run with no mask.

        Raises:
            ValueError: if ``context`` is missing or not 3-D.
        """
        # Validate conditioning before doing any heavy work.
        if context is None or context.dim() != 3:
            raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")

        B, _, H, W = x.shape
        p = self.patch_size

        # Pad bottom/right up to the patch grid; otherwise unfold silently drops
        # the remainder while the pixel stage sees the full image, and the final
        # view/fold fail. Circular padding matches ComfyUI's pad_to_patch_size.
        pad_h = (p - H % p) % p
        pad_w = (p - W % p) % p
        padded = bool(pad_h or pad_w)
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="circular")
        H_pad, W_pad = H + pad_h, W + pad_w
        Hs, Ws = H_pad // p, W_pad // p
        L = Hs * Ws

        pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
        # [B, C*p*p, L] -> [B, L, C*p*p]
        x_patches = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)
        t_emb = self.t_embedder(timesteps.reshape(-1), x.dtype).view(B, -1, self.hidden_size)

        # Truncate so the learned text position table is never over-indexed.
        Ltxt = min(context.shape[1], self.txt_max_length)
        y_emb = self.y_embedder(context[:, :Ltxt, :]).view(B, Ltxt, self.hidden_size)
        y_pos = self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype)
        y_emb = y_emb + y_pos
        condition = F.silu(t_emb)
        pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None

        # Patch stage: joint image/text MMDiT blocks.
        s = self.s_embedder(x_patches)
        for blk in self.patch_blocks:
            s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
        # One conditioning vector per patch for the pixel stage.
        s_cond = F.silu(t_emb + s).reshape(B * L, self.hidden_size)

        # Pixel stage.
        x_pixels = self.pixel_embedder(x, img_height=H_pad, img_width=W_pad, patch_size=p)
        for blk in self.pixel_blocks:
            x_pixels = blk(x_pixels, s_cond, H_pad, W_pad, p, mask=None, transformer_options=transformer_options)
        x_pixels = self.final_layer(x_pixels)

        # Back to fold's expected layout [B, C_out * p*p, L] (channel-major, like unfold).
        C_out = self.out_channels
        P2 = p * p
        x_pixels = x_pixels.reshape(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
        out = F.fold(x_pixels, (H_pad, W_pad), kernel_size=p, stride=p)
        if padded:
            out = out[:, :, :H, :W]
        return out