ComfyUI/comfy/ldm/nucleus/model.py
2026-04-18 22:16:23 -04:00

1043 lines
41 KiB
Python

# Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_nucleusmoe_image.py
# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved.
# Apache 2.0 License
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ldm.common_dit
import comfy.ops
import comfy.patcher_extension
from comfy.quant_ops import QUANT_ALGOS, QuantizedTensor, get_layout_class
from comfy.ldm.flux.math import apply_rope1
class NucleusMoETimestepProjEmbeddings(nn.Module):
    """Timestep conditioning: sinusoidal projection -> MLP embedder -> RMSNorm.

    ``time_proj`` (``Timesteps`` with flip_sin_to_cos=True, scale=1000) turns the
    raw timestep into ``embedding_dim`` features, which are cast to the dtype of
    ``hidden_states``, passed through ``TimestepEmbedding``
    (time_embed_dim = 4 * embedding_dim, out_dim = embedding_dim) and finally
    RMS-normalized.
    """

    def __init__(self, embedding_dim, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.time_proj = Timesteps(
            num_channels=embedding_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            scale=1000,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=embedding_dim,
            time_embed_dim=embedding_dim * 4,
            out_dim=embedding_dim,
            operations=operations,
            **factory,
        )
        self.norm = operations.RMSNorm(embedding_dim, eps=1e-6, **factory)

    def forward(self, timestep, hidden_states):
        """Embed ``timestep``; ``hidden_states`` is only used for its dtype."""
        sinusoid = self.time_proj(timestep).to(dtype=hidden_states.dtype)
        embedded = self.timestep_embedder(sinusoid)
        return self.norm(embedded)
class NucleusMoEEmbedRope(nn.Module):
    """Rotary position tables for (frame, height, width) image tokens plus text.

    Two complex buffers with 4096 rows are precomputed:
    ``pos_freqs`` for positions 0..4095 and ``neg_freqs`` for -4096..-1.
    Each row concatenates the unit rotations for the three axes, axis ``i``
    contributing ``axes_dim[i] // 2`` complex values.  With ``scale_rope`` the
    height/width axes are centred (half negative, half non-negative positions).
    ``dtype``, ``device`` and ``operations`` are accepted but unused.
    """

    def __init__(self, theta: int, axes_dim: list, scale_rope=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.scale_rope = scale_rope
        table_len = 4096
        positive = torch.arange(table_len)
        negative = -torch.arange(table_len).flip(0) - 1  # -4096, ..., -1

        def build(index):
            per_axis = [self._rope_params(index, self.axes_dim[axis], self.theta) for axis in range(3)]
            return torch.cat(per_axis, dim=1)

        self.register_buffer("pos_freqs", build(positive))
        self.register_buffer("neg_freqs", build(negative))

    @staticmethod
    def _rope_params(index, dim, theta=10000):
        """Return a (len(index), dim // 2) complex table of exp(i * index * inv_freq)."""
        assert dim % 2 == 0
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(index.float(), inv_freq)
        return torch.polar(torch.ones_like(angles), angles)

    def _compute_video_freqs(self, frame, height, width, idx=0, device=None):
        """Build the (frame * height * width, sum(axes_dim) // 2) table for one video.

        Frame positions start at ``idx``.
        """
        chunk_sizes = [d // 2 for d in self.axes_dim]
        pos_axes = self.pos_freqs.to(device).split(chunk_sizes, dim=1)
        neg_axes = self.neg_freqs.to(device).split(chunk_sizes, dim=1)
        grid = (frame, height, width, -1)

        frame_part = pos_axes[0][idx : idx + frame].view(frame, 1, 1, -1).expand(*grid)
        if self.scale_rope:
            # Centre the spatial axes: the first ceil(n/2) rows take negative positions.
            n_neg_h = height - height // 2
            n_neg_w = width - width // 2
            h_table = torch.cat([neg_axes[1][-n_neg_h:], pos_axes[1][: height // 2]], dim=0)
            w_table = torch.cat([neg_axes[2][-n_neg_w:], pos_axes[2][: width // 2]], dim=0)
        else:
            h_table = pos_axes[1][:height]
            w_table = pos_axes[2][:width]
        height_part = h_table.view(1, height, 1, -1).expand(*grid)
        width_part = w_table.view(1, 1, width, -1).expand(*grid)

        merged = torch.cat([frame_part, height_part, width_part], dim=-1)
        return merged.reshape(frame * height * width, -1).clone().contiguous()

    def forward(self, video_fhw, device=None, max_txt_seq_len=None):
        """Return (image_freqs, text_freqs).

        Text positions start right after the largest spatial index used by the
        image tokens and span ``max_txt_seq_len`` rows of ``pos_freqs``.
        Raises ValueError when ``max_txt_seq_len`` is missing, no shape is given,
        or the table is too short.
        """
        if max_txt_seq_len is None:
            raise ValueError("max_txt_seq_len must be provided")
        if isinstance(video_fhw, list) and video_fhw:
            video_fhw = video_fhw[0]
        shapes = video_fhw if isinstance(video_fhw, list) else [video_fhw]

        txt_len = int(max_txt_seq_len)
        per_video = []
        extents = []
        for idx, (frame, height, width) in enumerate(shapes):
            per_video.append(self._compute_video_freqs(frame, height, width, idx, device))
            extents.append(max(height // 2, width // 2) if self.scale_rope else max(height, width))
        if not extents:
            raise ValueError("video_fhw must contain at least one image shape")

        start = max(extents)
        end_index = start + txt_len
        available = self.pos_freqs.shape[0]
        if end_index > available:
            raise ValueError(
                f"Nucleus RoPE requires {end_index} positions, "
                f"but only {available} are available."
            )
        txt_freqs = self.pos_freqs.to(device)[start:end_index]
        return torch.cat(per_video, dim=0), txt_freqs
def _apply_rotary_emb_nucleus(x, freqs_cis):
"""Apply rotary embeddings using complex multiplication.
x: (B, S, H, D) tensor
freqs_cis: (S, D/2) complex tensor
"""
if x.shape[1] == 0:
return x
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1) # (S, 1, D/2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(-2)
return x_out.type_as(x)
class NucleusMoEAttention(nn.Module):
    """Grouped-query attention in which image queries attend to image + text keys.

    Image tokens produce Q/K/V (``to_q``/``to_k``/``to_v``); text tokens produce
    only K/V (``add_k_proj``/``add_v_proj``).  Keys and values are concatenated
    image-first and expanded from ``num_kv_heads`` to ``heads`` before attention.
    Precomputed text K/V may be passed as ``cached_txt_key``/``cached_txt_value``.
    """

    def __init__(
        self,
        query_dim: int,
        dim_head: int = 128,
        heads: int = 16,
        num_kv_heads: int = 4,
        eps: float = 1e-5,
        bias: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.inner_dim = heads * dim_head
        self.inner_kv_dim = num_kv_heads * dim_head
        self.num_kv_heads = num_kv_heads
        self.num_kv_groups = heads // num_kv_heads

        def linear(in_features, out_features, use_bias):
            return operations.Linear(in_features, out_features, bias=use_bias, dtype=dtype, device=device)

        # Image-stream projections.
        self.to_q = linear(query_dim, self.inner_dim, bias)
        self.to_k = linear(query_dim, self.inner_kv_dim, bias)
        self.to_v = linear(query_dim, self.inner_kv_dim, bias)
        # Text-stream projections (keys and values only).
        self.add_k_proj = linear(query_dim, self.inner_kv_dim, bias)
        self.add_v_proj = linear(query_dim, self.inner_kv_dim, bias)
        # Per-head RMS norms.
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        # NOTE(review): norm_added_q is never used in forward; presumably kept so checkpoint keys load.
        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.to_out = nn.ModuleList([linear(self.inner_dim, query_dim, False)])

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
        cached_txt_key=None,
        cached_txt_value=None,
        transformer_options={},
    ):
        """Attend image tokens ``hidden_states`` (B, S_img, dim) to image+text keys.

        ``image_rotary_emb`` is an (image_freqs, text_freqs) pair.
        ``attention_mask`` is an additive (B, S_txt) mask over text keys;
        image keys are never masked.
        """
        bsz, img_len = hidden_states.shape[:2]

        query = self.norm_q(self.to_q(hidden_states).view(bsz, img_len, self.heads, -1))
        key = self.norm_k(self.to_k(hidden_states).view(bsz, img_len, self.num_kv_heads, -1))
        value = self.to_v(hidden_states).view(bsz, img_len, self.num_kv_heads, -1)

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            query = _apply_rotary_emb_nucleus(query, img_freqs)
            key = _apply_rotary_emb_nucleus(key, img_freqs)

        has_cache = cached_txt_key is not None and cached_txt_value is not None
        if has_cache:
            # Cached text K/V are used as-is (no projection, norm or RoPE here).
            key = torch.cat([key, cached_txt_key], dim=1)
            value = torch.cat([value, cached_txt_value], dim=1)
        elif encoder_hidden_states is not None:
            txt_len = encoder_hidden_states.shape[1]
            txt_key = self.add_k_proj(encoder_hidden_states).view(bsz, txt_len, self.num_kv_heads, -1)
            txt_key = self.norm_added_k(txt_key)
            txt_value = self.add_v_proj(encoder_hidden_states).view(bsz, txt_len, self.num_kv_heads, -1)
            if image_rotary_emb is not None:
                txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs)
            key = torch.cat([key, txt_key], dim=1)
            value = torch.cat([value, txt_value], dim=1)

        # GQA: each KV head serves num_kv_groups query heads.
        if self.num_kv_groups > 1:
            key = key.repeat_interleave(self.num_kv_groups, dim=2)
            value = value.repeat_interleave(self.num_kv_groups, dim=2)

        # (B, S, H, D) -> (B, H, S, D)
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        attn_mask = None
        if attention_mask is not None:
            txt_mask_len = attention_mask.shape[-1]
            attn_mask = hidden_states.new_zeros((bsz, 1, txt_mask_len + img_len))
            attn_mask[:, 0, img_len:] = attention_mask

        out = optimized_attention_masked(
            query, key, value, self.heads,
            attn_mask, transformer_options=transformer_options,
            skip_reshape=True,
        )
        return self.to_out[0](out)
class SwiGLUExperts(nn.Module):
    """SwiGLU feed-forward experts for the MoE layer.

    Two weight layouts are supported:

    * packed: per-expert stacked ``gate_up_proj`` (kept on this module as
      ``weight``) and ``down_proj`` (kept as ``bias``), applied as
      ``x @ gate_up_proj[e]`` and ``mid @ down_proj[e]``.  These may be
      quantized and may use ``F.grouped_mm`` on CUDA.  ``forward`` hands this
      module to ``comfy.ops.cast_bias_weight``, which presumably reads the
      tensors under the ``weight``/``bias`` names.
    * split: one ``Linear`` per expert in ``gate_up_projs`` / ``down_projs``,
      created lazily when a state dict containing those keys is loaded.
    """

    def __init__(
        self,
        hidden_size: int,
        moe_intermediate_dim: int,
        num_experts: int,
        use_grouped_mm: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.moe_intermediate_dim = moe_intermediate_dim
        self.hidden_size = hidden_size
        self.use_grouped_mm = use_grouped_mm
        # Set once grouped_mm raises so later calls go straight to the loop path.
        self._grouped_mm_failed = False
        # Kept for lazily building split experts in _load_from_state_dict.
        self._dtype = dtype
        self._device = device
        self._operations = operations

    def _has_packed_experts(self):
        """True when packed ``weight`` (gate_up) and ``bias`` (down) tensors are registered."""
        return getattr(self, "weight", None) is not None and getattr(self, "bias", None) is not None

    def _register_packed_experts(self, gate_up, down):
        """Store packed tensors as frozen ``weight``/``bias`` parameters and enable comfy casting."""
        self.weight = nn.Parameter(gate_up, requires_grad=False)
        self.bias = nn.Parameter(down, requires_grad=False)
        self.comfy_cast_weights = True
        self.weight_function = []
        self.bias_function = []

    def _pop_quant_conf(self, state_dict, key):
        """Pop and JSON-decode the byte-tensor quantization config at ``key``, or return None."""
        quant_conf = state_dict.pop(key, None)
        if quant_conf is None:
            return None
        return json.loads(quant_conf.cpu().numpy().tobytes())

    def _load_packed_quant_tensor(self, state_dict, key, tensor, quant_conf):
        """Wrap ``tensor`` as a QuantizedTensor per ``quant_conf`` (no-op when conf is None).

        The scale is popped from ``{key}_scale`` or, failing that, ``{key}.weight_scale``.
        Raises ValueError on a missing/unsupported format or a missing scale.
        """
        if quant_conf is None:
            return tensor
        quant_format = quant_conf.get("format", None)
        if quant_format is None:
            raise ValueError(f"Missing quantization format for Nucleus packed expert tensor {key}")
        if quant_format not in QUANT_ALGOS:
            raise ValueError(f"Unsupported quantization format {quant_format} for Nucleus packed expert tensor {key}")
        qconfig = QUANT_ALGOS[quant_format]
        layout_type = qconfig["comfy_tensor_layout"]
        layout_cls = get_layout_class(layout_type)
        scale = state_dict.pop(f"{key}_scale", None)
        if scale is None:
            scale = state_dict.pop(f"{key}.weight_scale", None)
        if scale is None:
            raise ValueError(f"Missing quantization scale for Nucleus packed expert tensor {key}")
        orig_dtype = self._dtype or torch.bfloat16
        params = layout_cls.Params(
            scale=scale,
            orig_dtype=orig_dtype,
            orig_shape=tuple(tensor.shape),
        )
        return QuantizedTensor(tensor.to(dtype=qconfig["storage_t"]), layout_type, params)

    def _dequantize_for_expert_mm(self, tensor, dtype):
        """Return a plain tensor of ``dtype`` (dequantizing a QuantizedTensor first)."""
        if isinstance(tensor, QuantizedTensor):
            tensor = tensor.dequantize()
        if tensor.dtype != dtype:
            tensor = tensor.to(dtype=dtype)
        return tensor

    def _build_split_experts(self):
        """Create per-expert Linear modules once (idempotent)."""
        if hasattr(self, "gate_up_projs"):
            return
        operations = self._operations
        self.gate_up_projs = nn.ModuleList([
            operations.Linear(self.hidden_size, 2 * self.moe_intermediate_dim, bias=False, dtype=self._dtype, device=self._device)
            for _ in range(self.num_experts)
        ])
        self.down_projs = nn.ModuleList([
            operations.Linear(self.moe_intermediate_dim, self.hidden_size, bias=False, dtype=self._dtype, device=self._device)
            for _ in range(self.num_experts)
        ])

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load packed (``gate_up_proj``/``down_proj`` or ``weight``/``bias``) or split expert weights.

        Packed tensors are popped from ``state_dict`` (optionally quantized via
        ``comfy_quant``) and registered directly; their keys are then removed
        from ``missing_keys``.  Split keys trigger creation of the per-expert
        modules so the default child loading can fill them.
        """
        gate_up_key = f"{prefix}gate_up_proj"
        down_key = f"{prefix}down_proj"
        has_packed = gate_up_key in state_dict and down_key in state_dict
        internal_weight_key = f"{prefix}weight"
        internal_bias_key = f"{prefix}bias"
        has_internal_packed = internal_weight_key in state_dict and internal_bias_key in state_dict
        quant_conf_key = f"{prefix}comfy_quant"
        packed_quant_conf = self._pop_quant_conf(state_dict, quant_conf_key)
        split_prefixes = (f"{prefix}gate_up_projs.", f"{prefix}down_projs.")
        has_split = any(k.startswith(split_prefixes) for k in state_dict)
        if has_packed or has_internal_packed:
            packed_gate_up_key = gate_up_key if has_packed else internal_weight_key
            packed_down_key = down_key if has_packed else internal_bias_key
            gate_up = state_dict.pop(packed_gate_up_key)
            down = state_dict.pop(packed_down_key)
            gate_up = self._load_packed_quant_tensor(state_dict, packed_gate_up_key, gate_up, packed_quant_conf)
            down = self._load_packed_quant_tensor(state_dict, packed_down_key, down, packed_quant_conf)
            self._register_packed_experts(gate_up, down)
        elif has_split:
            self._build_split_experts()
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        if has_packed or has_internal_packed:
            # The packed tensors were popped above, so the base loader reports them missing.
            for key in (gate_up_key, down_key, internal_weight_key, internal_bias_key, quant_conf_key):
                if key in missing_keys:
                    missing_keys.remove(key)

    def _run_experts_split(self, x, num_tokens_per_expert):
        """Run consecutive row groups of ``x`` through the per-expert Linear modules.

        Rows beyond sum(num_tokens_per_expert) are left as zeros.
        """
        output = torch.zeros_like(x)
        cumsum = 0
        tokens_list = num_tokens_per_expert.tolist()
        for i in range(self.num_experts):
            n_tokens = int(tokens_list[i])
            if n_tokens == 0:
                continue
            expert_input = x[cumsum : cumsum + n_tokens]
            cumsum += n_tokens
            gate_up = self.gate_up_projs[i](expert_input)
            gate, up = gate_up.chunk(2, dim=-1)
            mid = F.silu(gate) * up
            expert_output = self.down_projs[i](mid)
            output[cumsum - n_tokens : cumsum] = expert_output
        return output

    def _run_experts_packed_for_loop(self, x, num_tokens_per_expert, gate_up_proj, down_proj):
        """Per-expert matmul fallback for packed weights; trailing padding rows become zeros."""
        tokens_list = num_tokens_per_expert.tolist()
        num_real_tokens = sum(tokens_list)
        num_padding = x.shape[0] - num_real_tokens
        x_per_expert = torch.split(x[:num_real_tokens], split_size_or_sections=tokens_list, dim=0)
        expert_outputs = []
        for expert_idx, x_expert in enumerate(x_per_expert):
            if x_expert.shape[0] == 0:
                continue
            gate_up = torch.matmul(x_expert, gate_up_proj[expert_idx])
            gate, up = gate_up.chunk(2, dim=-1)
            expert_outputs.append(torch.matmul(F.silu(gate) * up, down_proj[expert_idx]))
        if len(expert_outputs) > 0:
            out = torch.cat(expert_outputs, dim=0)
        else:
            out = x.new_zeros((0, self.hidden_size))
        if num_padding > 0:
            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
        return out

    def _can_use_grouped_mm(self, x, gate_up_proj):
        """Whether the grouped GEMM kernel may be attempted for these inputs."""
        return (
            self.use_grouped_mm
            and not self._grouped_mm_failed
            and hasattr(F, "grouped_mm")
            and x.is_cuda
            and x.dtype in (torch.float16, torch.bfloat16, torch.float32)
            and gate_up_proj.dtype in (torch.float16, torch.bfloat16, torch.float32)
        )

    def _run_experts_grouped_mm(self, x, num_tokens_per_expert, gate_up_proj, down_proj):
        """Packed experts via ``F.grouped_mm`` using int32 cumulative row offsets."""
        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
        gate_up = F.grouped_mm(x, gate_up_proj, offs=offsets)
        gate, up = gate_up.chunk(2, dim=-1)
        out = F.grouped_mm(F.silu(gate) * up, down_proj, offs=offsets)
        return out.type_as(x)

    def forward(self, x, num_tokens_per_expert):
        """Run experts on pre-permuted tokens.

        x: (total_tokens, hidden_size), rows grouped by expert in order.
        num_tokens_per_expert: (num_experts,) row count per expert.
        Raises RuntimeError if neither packed nor split weights were loaded.
        """
        if self._has_packed_experts():
            cast_weight, cast_bias, offload_stream = comfy.ops.cast_bias_weight(
                self, x, offloadable=True, compute_dtype=x.dtype
            )
            try:
                gate_up_proj = self._dequantize_for_expert_mm(cast_weight, x.dtype)
                down_proj = self._dequantize_for_expert_mm(cast_bias, x.dtype)
                if self._can_use_grouped_mm(x, gate_up_proj):
                    try:
                        return self._run_experts_grouped_mm(x, num_tokens_per_expert, gate_up_proj, down_proj)
                    except RuntimeError:
                        # Unsupported hardware/dtype/layout: remember and use the loop fallback.
                        self._grouped_mm_failed = True
                return self._run_experts_packed_for_loop(x, num_tokens_per_expert, gate_up_proj, down_proj)
            finally:
                # Release exactly what cast_bias_weight handed out, not the dequantized copies.
                comfy.ops.uncast_bias_weight(self, cast_weight, cast_bias, offload_stream)
        if hasattr(self, "gate_up_projs"):
            return self._run_experts_split(x, num_tokens_per_expert)
        raise RuntimeError("Nucleus MoE experts have not loaded packed or split weights.")
class NucleusMoELayer(nn.Module):
    """Expert-choice mixture-of-experts MLP with an always-on shared expert.

    Every expert selects its top ``capacity`` tokens per batch item by router
    affinity (rather than tokens selecting experts).  A token's routing weights
    are renormalized to sum to one over the experts that picked it, multiplied
    by ``route_scale``, and the weighted expert outputs are added to the
    ``shared_expert`` output computed for every token.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        moe_intermediate_dim: int,
        capacity_factor: float = 2.0,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
        use_grouped_mm: bool = False,
        timestep_dim: int = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid
        self.route_scale = route_scale
        if timestep_dim is None:
            timestep_dim = hidden_size
        # Router: takes concatenated [timestep; hidden_states]
        self.gate = operations.Linear(
            hidden_size + timestep_dim, num_experts, bias=False,
            dtype=dtype, device=device,
        )
        # Routed experts
        self.experts = SwiGLUExperts(
            hidden_size=hidden_size,
            moe_intermediate_dim=moe_intermediate_dim,
            num_experts=num_experts,
            use_grouped_mm=use_grouped_mm,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Shared expert (processes ALL tokens) - FeedForward structure for checkpoint compat
        self.shared_expert = FeedForward(hidden_size, dim_out=hidden_size, inner_dim=moe_intermediate_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, hidden_states, hidden_states_unmodulated, timestep=None):
        """Expert-choice routing.

        hidden_states: (B, S, D) modulated tokens fed to the experts.
        hidden_states_unmodulated: (B, S, D) tokens used (with ``timestep``) for routing.
        timestep: (B, D) conditioning vector.  NOTE(review): ``self.gate`` expects
        hidden_size + timestep_dim inputs, so ``timestep=None`` presumably only
        works when that width matches D.
        Returns (B, S, D).
        """
        B, S, D = hidden_states.shape
        if timestep is not None:
            timestep_expanded = timestep.unsqueeze(1).expand(-1, S, -1)
            router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1)
        else:
            router_input = hidden_states_unmodulated

        # Routing scores, computed in float32 for stability.
        logits = self.gate(router_input)  # (B, S, num_experts)
        if self.use_sigmoid:
            scores = torch.sigmoid(logits.float()).to(logits.dtype)
        else:
            scores = F.softmax(logits.float(), dim=-1).to(logits.dtype)

        # Expert-Choice: each expert keeps its top-`capacity` tokens per batch item.
        # Clamp to S so torch.topk never asks for more tokens than exist.
        capacity = min(S, max(1, math.ceil(self.capacity_factor * S / self.num_experts)))
        affinity = scores.transpose(1, 2)  # (B, num_experts, S)
        top_indices = torch.topk(affinity, k=capacity, dim=-1).indices
        gating = affinity.gather(dim=-1, index=top_indices)

        # Flatten to expert-major order: all of expert 0's picks (every batch item), then expert 1, ...
        batch_offsets = torch.arange(B, device=hidden_states.device, dtype=torch.long).view(B, 1, 1) * S
        global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(-1)
        gating_flat = gating.transpose(0, 1).reshape(-1)

        # Renormalize each token's weights over the experts that selected it.
        token_score_sums = torch.zeros(B * S, device=hidden_states.device, dtype=gating_flat.dtype)
        token_score_sums.scatter_add_(0, global_token_indices, gating_flat)
        gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12)
        gating_flat = gating_flat * self.route_scale

        # Gather selected tokens and run the experts (B * capacity rows per expert).
        hidden_flat = hidden_states.reshape(-1, D)
        selected_hidden = hidden_flat[global_token_indices]  # (total_selected, D)
        num_tokens_per_expert = torch.full(
            (self.num_experts,),
            B * capacity,
            dtype=torch.long,
            device=hidden_states.device,
        )
        expert_output = self.experts(selected_hidden, num_tokens_per_expert)
        expert_output = (expert_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype)

        # Add routed contributions on top of the shared expert output.
        output = self.shared_expert(hidden_states).reshape(B * S, D)
        scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, D)
        output = output.scatter_add(dim=0, index=scatter_idx, src=expert_output)
        return output.view(B, S, D)
class SwiGLUProj(nn.Module):
    """Fused gated projection returning ``value * SiLU(gate)``.

    ``proj`` maps ``dim`` to ``2 * inner_dim`` features laid out as
    [value | gate].  The attribute is named ``proj`` so checkpoint keys
    (``...proj.weight``) match.
    """

    def __init__(self, dim, inner_dim, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim, 2 * inner_dim, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.silu(gate)
class FeedForward(nn.Module):
    """Dense SwiGLU MLP: SwiGLUProj(dim -> inner_dim) then Linear(inner_dim -> dim_out).

    Submodules sit in ``net`` at indices 0 and 2, with an ``nn.Identity``
    placeholder at index 1, so state-dict keys are ``net.0.proj.weight`` and
    ``net.2.weight``.  ``inner_dim`` defaults to ``int(dim * mult)`` and
    ``dim_out`` defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None, inner_dim=None, mult=4, dtype=None, device=None, operations=None):
        super().__init__()
        hidden = int(dim * mult) if inner_dim is None else inner_dim
        out_features = dim if dim_out is None else dim_out
        gated_in = SwiGLUProj(dim, hidden, bias=False, dtype=dtype, device=device, operations=operations)
        placeholder = nn.Identity()  # keeps index 2 for the output projection
        project_out = operations.Linear(hidden, out_features, bias=False, dtype=dtype, device=device)
        self.net = nn.ModuleList([gated_in, placeholder, project_out])

    def forward(self, hidden_states):
        gated_in, _, project_out = self.net
        return project_out(gated_in(hidden_states))
class NucleusMoEImageTransformerBlock(nn.Module):
    """One image layer: modulated attention followed by a modulated dense MLP or MoE.

    ``img_mod(temb)`` yields (scale1, gate1, scale2, gate2).  Each sub-layer
    normalizes with a non-affine LayerNorm, scales by ``1 + scale`` and adds its
    output to the residual stream weighted by ``tanh(clamp(gate, -2, 2))``.
    Text tokens are projected by ``encoder_proj`` and serve only as keys/values.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_kv_heads: int,
        joint_attention_dim: int = 4096,
        is_moe: bool = False,
        num_experts: int = 64,
        moe_intermediate_dim: int = 1344,
        capacity_factor: float = 2.0,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
        use_grouped_mm: bool = False,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.is_moe = is_moe
        factory = {"dtype": dtype, "device": device}

        # Modulation head: 4 * dim outputs -> scale1, gate1, scale2, gate2.
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 4 * dim, bias=True, **factory),
        )
        self.pre_attn_norm = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, **factory)
        self.pre_mlp_norm = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, **factory)
        self.attn = NucleusMoEAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
            bias=False,
            operations=operations,
            **factory,
        )
        # Projects text features into the model width before attention.
        self.encoder_proj = operations.Linear(joint_attention_dim, dim, bias=True, **factory)

        if is_moe:
            self.img_mlp = NucleusMoELayer(
                hidden_size=dim,
                num_experts=num_experts,
                moe_intermediate_dim=moe_intermediate_dim,
                capacity_factor=capacity_factor,
                use_sigmoid=use_sigmoid,
                route_scale=route_scale,
                use_grouped_mm=use_grouped_mm,
                operations=operations,
                **factory,
            )
        else:
            # Dense blocks use 4x the per-expert width (matching the original model).
            self.img_mlp = FeedForward(
                dim=dim,
                dim_out=dim,
                inner_dim=4 * moe_intermediate_dim,
                operations=operations,
                **factory,
            )

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        image_rotary_emb=None,
        attention_kwargs=None,
        transformer_options={},
    ):
        """Update image tokens ``hidden_states`` given text ``encoder_hidden_states`` and ``temb``.

        ``attention_kwargs`` may carry ``attention_mask``, ``cached_txt_key`` and
        ``cached_txt_value``; when a cached key is present the text projection is skipped.
        """
        modulation = self.img_mod(temb).unsqueeze(1)
        scale1, gate1, scale2, gate2 = modulation.chunk(4, dim=-1)
        attn_gate = gate1.clamp(min=-2.0, max=2.0).tanh()
        mlp_gate = gate2.clamp(min=-2.0, max=2.0).tanh()

        kwargs = attention_kwargs if attention_kwargs else {}
        cached_key = kwargs.get("cached_txt_key")
        cached_value = kwargs.get("cached_txt_value")
        context = self.encoder_proj(encoder_hidden_states) if cached_key is None else None

        attn_in = self.pre_attn_norm(hidden_states) * (1 + scale1)
        attn_out = self.attn(
            hidden_states=attn_in,
            encoder_hidden_states=context,
            attention_mask=kwargs.get("attention_mask"),
            image_rotary_emb=image_rotary_emb,
            cached_txt_key=cached_key,
            cached_txt_value=cached_value,
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + attn_gate * attn_out

        mlp_normed = self.pre_mlp_norm(hidden_states)
        mlp_in = mlp_normed * (1 + scale2)
        if self.is_moe:
            # MoE routes on the un-modulated normalized tokens plus temb.
            mlp_out = self.img_mlp(mlp_in, mlp_normed, timestep=temb)
        else:
            mlp_out = self.img_mlp(mlp_in)
        hidden_states = hidden_states + mlp_gate * mlp_out

        # Keep fp16 activations finite.
        if hidden_states.dtype == torch.float16:
            limits = torch.finfo(torch.float16)
            hidden_states = hidden_states.clip(limits.min, limits.max)
        return hidden_states
class AdaLayerNormContinuous(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    ``linear(silu(temb))`` gives [scale | shift] (scale first) and the output is
    ``norm(x) * (1 + scale) + shift``, broadcast over the sequence axis of ``x``.
    """

    def __init__(self, embedding_dim, conditioning_embedding_dim, elementwise_affine=False, norm_eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(conditioning_embedding_dim, 2 * embedding_dim, bias=True, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, norm_eps, elementwise_affine, dtype=dtype, device=device)

    def forward(self, x, temb):
        params = self.linear(self.silu(temb).to(x.dtype))
        scale, shift = params.chunk(2, dim=1)
        scale = scale.unsqueeze(1)
        shift = shift.unsqueeze(1)
        return self.norm(x) * (1 + scale) + shift
class NucleusMoEImageTransformer2DModel(nn.Module):
    """Nucleus-Image mixture-of-experts diffusion transformer.

    Latents (B, C, T, H, W) are padded to whole ``patch_size`` patches, packed
    into tokens of ``C * patch_size**2`` features, projected to ``inner_dim``
    and processed by ``num_layers`` blocks whose MLPs are dense or MoE.  The
    result is projected to ``patch_size**2 * out_channels`` features, unpacked
    to (B, out_channels, T, H, W) and negated (see the note in ``_forward``).
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: int = 16,
        num_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 4,
        joint_attention_dim: int = 4096,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
        rope_theta: int = 10000,
        scale_rope: bool = True,
        dense_moe_strategy: str = "leave_first_three_blocks_dense",
        num_experts: int = 64,
        moe_intermediate_dim: int = 1344,
        capacity_factors: list = None,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
        use_grouped_mm: bool = False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.dtype = dtype
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        # Which layers are MoE: explicit per-layer capacity factors (> 0) win,
        # otherwise the named dense/MoE strategy decides.
        if capacity_factors is not None:
            moe_layers = [cf > 0 for cf in capacity_factors]
        else:
            moe_layers = [self._is_moe_layer(dense_moe_strategy, i, num_layers) for i in range(num_layers)]
        self.moe_layers = moe_layers
        # RoPE
        self.pos_embed = NucleusMoEEmbedRope(
            theta=rope_theta,
            axes_dim=list(axes_dims_rope),
            scale_rope=scale_rope,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Timestep embedding
        self.time_text_embed = NucleusMoETimestepProjEmbeddings(
            embedding_dim=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Input projections
        self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
        self.img_in = operations.Linear(in_channels * patch_size * patch_size, self.inner_dim, dtype=dtype, device=device)
        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            NucleusMoEImageTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                num_kv_heads=num_key_value_heads,
                joint_attention_dim=joint_attention_dim,
                is_moe=moe_layers[i],
                num_experts=num_experts,
                moe_intermediate_dim=moe_intermediate_dim,
                capacity_factor=capacity_factors[i] if capacity_factors is not None else (2.0 if moe_layers[i] else 0.0),
                use_sigmoid=use_sigmoid,
                route_scale=route_scale,
                use_grouped_mm=use_grouped_mm,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for i in range(num_layers)
        ])
        # Output
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, norm_eps=1e-6, dtype=dtype, device=device, operations=operations)
        self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False, dtype=dtype, device=device)

    @staticmethod
    def _is_moe_layer(strategy, layer_idx, num_layers):
        """Return whether block ``layer_idx`` is MoE under ``strategy`` (unknown strategies -> MoE)."""
        if strategy == "leave_first_three_and_last_block_dense":
            return layer_idx >= 3 and layer_idx < num_layers - 1
        elif strategy == "leave_first_three_blocks_dense":
            return layer_idx >= 3
        elif strategy == "leave_first_block_dense":
            return layer_idx >= 1
        elif strategy == "all_moe":
            return True
        elif strategy == "all_dense":
            return False
        return True

    @staticmethod
    def _normalize_attention_mask(attention_mask, dtype):
        """Convert a text mask to a 2-D additive mask of ``dtype``.

        Integer/bool masks, and float masks containing only 0/1, are treated as
        keep(1)/drop(0) and mapped to 0 / -finfo(dtype).max.  Other float masks
        are assumed to already be additive and are only cast.
        """
        if attention_mask is None:
            return None
        if attention_mask.ndim > 2:
            attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)
        if not torch.is_floating_point(attention_mask):
            return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
        if torch.all((attention_mask == 0) | (attention_mask == 1)):
            return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
        return attention_mask.to(dtype)

    def process_img(self, x, index=0, h_offset=0, w_offset=0):
        """Pad ``x`` to whole patches and flatten it into a token sequence.

        x: (B, C, T, H, W) latent.
        Returns (tokens, (T, H_patches, W_patches), padded_shape), where tokens is
        (B, T * H_patches * W_patches, C * patch_size**2) with features ordered
        (channel, patch_row, patch_col).  ``index``, ``h_offset`` and ``w_offset``
        are accepted for interface compatibility but are unused.
        """
        p = self.patch_size
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, p, p))
        orig_shape = hidden_states.shape
        bs, channels, frames, h_pad, w_pad = orig_shape
        # Assumes pad_to_patch_size rounded H and W up to multiples of p (the view requires it).
        h_len = h_pad // p
        w_len = w_pad // p
        hidden_states = hidden_states.view(bs, channels, frames, h_len, p, w_len, p)
        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
        hidden_states = hidden_states.reshape(bs, frames * h_len * w_len, channels * p * p)
        return hidden_states, (frames, h_len, w_len), orig_shape

    def forward(self, x, timestep, context, attention_mask=None, transformer_options=None, **kwargs):
        """Entry point: run ``_forward`` through any registered DIFFUSION_MODEL wrappers."""
        # A fresh dict per call: _forward writes block bookkeeping into it.
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(
                comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
            ),
        ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)

    def _forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        transformer_options=None,
        control=None,
        **kwargs,
    ):
        """Denoiser forward pass.

        x: (B, C, T, H, W) latent; context: (B, S_txt, joint_attention_dim) text
        features; attention_mask: optional text mask (see _normalize_attention_mask).
        Returns the negated prediction cropped to x's spatial size.
        """
        if transformer_options is None:
            transformer_options = {}
        encoder_hidden_states = context
        encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype)
        block_attention_kwargs = {}
        if encoder_hidden_states_mask is not None:
            block_attention_kwargs["attention_mask"] = encoder_hidden_states_mask

        hidden_states, (t_len, h_len, w_len), orig_shape = self.process_img(x)
        num_embeds = hidden_states.shape[1]

        # Compute RoPE
        img_freqs, txt_freqs = self.pos_embed(
            video_fhw=[(t_len, h_len, w_len)],
            device=x.device,
            max_txt_seq_len=encoder_hidden_states.shape[1],
        )
        image_rotary_emb = (img_freqs, txt_freqs)

        # Project inputs
        hidden_states = self.img_in(hidden_states)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        temb = self.time_text_embed(timesteps, hidden_states)

        patches_replace = transformer_options.get("patches_replace", {})
        patches = transformer_options.get("patches", {})
        blocks_replace = patches_replace.get("dit", {})

        if "post_input" in patches:
            for p in patches["post_input"]:
                out = p({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out["txt"]

        transformer_options["total_blocks"] = len(self.transformer_blocks)
        transformer_options["block_type"] = "single"
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = block(
                        hidden_states=args["img"],
                        encoder_hidden_states=args["txt"],
                        temb=args["vec"],
                        image_rotary_emb=args["pe"],
                        attention_kwargs=block_attention_kwargs,
                        transformer_options=args["transformer_options"],
                    )
                    return {"img": out, "txt": args["txt"]}

                out = blocks_replace[("single_block", i)](
                    {"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options},
                    {"original_block": block_wrap},
                )
                hidden_states = out["img"]
                encoder_hidden_states = out.get("txt", encoder_hidden_states)
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=block_attention_kwargs,
                    transformer_options=transformer_options,
                )

            if "single_block" in patches:
                for p in patches["single_block"]:
                    out = p({
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "x": x,
                        "block_index": i,
                        "transformer_options": transformer_options,
                    })
                    hidden_states = out["img"]
                    encoder_hidden_states = out.get("txt", encoder_hidden_states)

            if control is not None:
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        hidden_states[:, :add.shape[1]] += add

        # Output
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Unpack tokens -> (B, out_channels, T, H_pad, W_pad); proj_out emits
        # patch_size**2 * out_channels features ordered (channel, patch_row, patch_col).
        p = self.patch_size
        bs, _, frames, h_pad, w_pad = orig_shape
        hidden_states = hidden_states[:, :num_embeds].reshape(
            bs, frames, h_pad // p, w_pad // p, self.out_channels, p, p,
        )
        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
        hidden_states = hidden_states.reshape(bs, self.out_channels, frames, h_pad, w_pad)
        # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
        return -hidden_states[:, :, :, :x.shape[-2], :x.shape[-1]]