mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
1043 lines
41 KiB
Python
1043 lines
41 KiB
Python
# Based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_nucleusmoe_image.py
|
|
# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved.
|
|
# Apache 2.0 License
|
|
import json
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Optional, Tuple
|
|
from einops import repeat
|
|
|
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
|
import comfy.ldm.common_dit
|
|
import comfy.ops
|
|
import comfy.patcher_extension
|
|
from comfy.quant_ops import QUANT_ALGOS, QuantizedTensor, get_layout_class
|
|
from comfy.ldm.flux.math import apply_rope1
|
|
|
|
|
|
class NucleusMoETimestepProjEmbeddings(nn.Module):
    """Embed diffusion timesteps into an RMS-normalised conditioning vector.

    Pipeline: sinusoidal features (``time_proj``) -> MLP with hidden width
    ``4 * embedding_dim`` (``timestep_embedder``) -> RMSNorm (``norm``).
    """

    def __init__(self, embedding_dim, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.time_proj = Timesteps(
            num_channels=embedding_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            scale=1000,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=embedding_dim,
            time_embed_dim=4 * embedding_dim,
            out_dim=embedding_dim,
            operations=operations,
            **factory,
        )
        self.norm = operations.RMSNorm(embedding_dim, eps=1e-6, **factory)

    def forward(self, timestep, hidden_states):
        """Return the normalised embedding of ``timestep``.

        ``hidden_states`` is only used for its dtype: the sinusoidal features
        are cast to it before entering the MLP.
        """
        sinusoid = self.time_proj(timestep)
        embedded = self.timestep_embedder(sinusoid.to(dtype=hidden_states.dtype))
        return self.norm(embedded)
|
|
|
|
|
|
class NucleusMoEEmbedRope(nn.Module):
    """Rotary position tables for (frame, height, width) image tokens and text tokens.

    Two complex tables with 4096 rows are precomputed at construction time:
    ``pos_freqs`` for positions 0..4095 and ``neg_freqs`` for -4096..-1. Each row
    concatenates the frequencies of the three axes, giving
    ``axes_dim[0]//2 + axes_dim[1]//2 + axes_dim[2]//2`` complex columns.
    """

    def __init__(self, theta: int, axes_dim: list, scale_rope=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.scale_rope = scale_rope

        positive = torch.arange(4096)
        negative = -torch.arange(4096).flip(0) - 1  # -4096, ..., -1

        for buffer_name, positions in (("pos_freqs", positive), ("neg_freqs", negative)):
            per_axis = [self._rope_params(positions, self.axes_dim[axis], self.theta) for axis in range(3)]
            self.register_buffer(buffer_name, torch.cat(per_axis, dim=1))

    @staticmethod
    def _rope_params(index, dim, theta=10000):
        """Return unit complex phases ``exp(i * index * theta^(-2k/dim))``, shape (len(index), dim // 2)."""
        assert dim % 2 == 0
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(index.float(), inv_freq)
        return torch.polar(torch.ones_like(angles), angles)

    def _compute_video_freqs(self, frame, height, width, idx=0, device=None):
        """Build the (frame * height * width, total/2) phase table for one image.

        The frame axis uses positive positions starting at ``idx``. With
        ``scale_rope`` the spatial axes are centred: the first ``ceil(n/2)`` rows
        take the last negative positions and the remaining ``n//2`` rows take
        positions 0, 1, ...; otherwise rows take positions 0..n-1.
        """
        half_dims = [d // 2 for d in self.axes_dim]
        pos_parts = self.pos_freqs.to(device).split(half_dims, dim=1)
        neg_parts = self.neg_freqs.to(device).split(half_dims, dim=1)
        grid = (frame, height, width, -1)

        frame_part = pos_parts[0][idx: idx + frame].view(frame, 1, 1, -1).expand(*grid)
        if self.scale_rope:
            height_part = torch.cat(
                [neg_parts[1][-(height - height // 2):], pos_parts[1][: height // 2]], dim=0
            )
            width_part = torch.cat(
                [neg_parts[2][-(width - width // 2):], pos_parts[2][: width // 2]], dim=0
            )
        else:
            height_part = pos_parts[1][:height]
            width_part = pos_parts[2][:width]
        height_part = height_part.view(1, height, 1, -1).expand(*grid)
        width_part = width_part.view(1, 1, width, -1).expand(*grid)

        table = torch.cat([frame_part, height_part, width_part], dim=-1)
        return table.reshape(frame * height * width, -1).clone().contiguous()

    def forward(self, video_fhw, device=None, max_txt_seq_len=None):
        """Return ``(image_freqs, text_freqs)`` phase tables.

        video_fhw: an ``(f, h, w)`` shape or a list. If a non-empty list is given,
            only its first element is kept; a non-list result is wrapped in a list,
            and one table per shape in it is concatenated along dim 0.
        max_txt_seq_len: required; number of text positions. Text rows start right
            after the largest image position index.
        Raises ValueError when ``max_txt_seq_len`` is missing, no shapes are given,
        or the text span would exceed the precomputed table.
        """
        if max_txt_seq_len is None:
            raise ValueError("max_txt_seq_len must be provided")

        if isinstance(video_fhw, list) and video_fhw:
            video_fhw = video_fhw[0]
        shapes = video_fhw if isinstance(video_fhw, list) else [video_fhw]

        txt_len = int(max_txt_seq_len)
        image_tables = []
        max_vid_index = None
        for idx, shape in enumerate(shapes):
            frame, height, width = shape
            image_tables.append(self._compute_video_freqs(frame, height, width, idx, device))
            extent = max(height // 2, width // 2) if self.scale_rope else max(height, width)
            if max_vid_index is None or extent > max_vid_index:
                max_vid_index = extent

        if max_vid_index is None:
            raise ValueError("video_fhw must contain at least one image shape")

        end_index = max_vid_index + txt_len
        available = self.pos_freqs.shape[0]
        if end_index > available:
            raise ValueError(
                f"Nucleus RoPE requires {end_index} positions, "
                f"but only {available} are available."
            )

        text_table = self.pos_freqs.to(device)[max_vid_index:end_index]
        return torch.cat(image_tables, dim=0), text_table
|
|
|
|
|
|
def _apply_rotary_emb_nucleus(x, freqs_cis):
|
|
"""Apply rotary embeddings using complex multiplication.
|
|
x: (B, S, H, D) tensor
|
|
freqs_cis: (S, D/2) complex tensor
|
|
"""
|
|
if x.shape[1] == 0:
|
|
return x
|
|
|
|
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
|
freqs_cis = freqs_cis.unsqueeze(1) # (S, 1, D/2)
|
|
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(-2)
|
|
return x_out.type_as(x)
|
|
|
|
|
|
class NucleusMoEAttention(nn.Module):
    """Grouped-query attention where image queries attend to image + text keys/values.

    Image tokens produce queries, keys and values; text tokens produce only keys
    and values. The ``num_kv_heads`` key/value heads are repeated to match the
    query heads before attention. Pre-computed text K/V can be supplied through
    ``cached_txt_key`` / ``cached_txt_value`` (both must be given), in which case
    ``encoder_hidden_states`` is ignored.
    """

    def __init__(
        self,
        query_dim: int,
        dim_head: int = 128,
        heads: int = 16,
        num_kv_heads: int = 4,
        eps: float = 1e-5,
        bias: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        self.inner_dim = heads * dim_head
        self.inner_kv_dim = num_kv_heads * dim_head
        self.num_kv_heads = num_kv_heads
        # Query heads per K/V head (integer division; heads should be a multiple of num_kv_heads).
        self.num_kv_groups = heads // num_kv_heads

        def project(out_features):
            return operations.Linear(query_dim, out_features, bias=bias, dtype=dtype, device=device)

        def head_norm(**extra):
            return operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device, **extra)

        # Image-side Q/K/V.
        self.to_q = project(self.inner_dim)
        self.to_k = project(self.inner_kv_dim)
        self.to_v = project(self.inner_kv_dim)
        # Text-side K/V (text tokens never produce queries).
        self.add_k_proj = project(self.inner_kv_dim)
        self.add_v_proj = project(self.inner_kv_dim)

        # Per-head RMS norms.
        self.norm_q = head_norm(elementwise_affine=True)
        self.norm_k = head_norm(elementwise_affine=True)
        self.norm_added_k = head_norm()
        # Not used by forward(); NOTE(review): presumably kept so checkpoint keys load.
        self.norm_added_q = head_norm()

        self.to_out = nn.ModuleList([
            operations.Linear(self.inner_dim, query_dim, bias=False, dtype=dtype, device=device),
        ])

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
        cached_txt_key=None,
        cached_txt_value=None,
        transformer_options={},
    ):
        """Attend from image tokens to the joint [image, text] key/value sequence.

        hidden_states: image tokens (B, S_img, query_dim).
        encoder_hidden_states: text tokens (B, S_txt, query_dim); used only when no cache is given.
        attention_mask: values for the text key positions (last dim S_txt); the image
            key positions, which come first, get 0.
        image_rotary_emb: ``(img_freqs, txt_freqs)``; applied to image Q/K and, when
            text K is computed here, to text K. Cached text K is used as-is.
        Returns the attention output projected by ``to_out[0]``.
        """
        batch = hidden_states.shape[0]
        img_len = hidden_states.shape[1]

        query = self.norm_q(self.to_q(hidden_states).view(batch, img_len, self.heads, -1))
        key = self.norm_k(self.to_k(hidden_states).view(batch, img_len, self.num_kv_heads, -1))
        value = self.to_v(hidden_states).view(batch, img_len, self.num_kv_heads, -1)

        txt_freqs = None
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            query = _apply_rotary_emb_nucleus(query, img_freqs)
            key = _apply_rotary_emb_nucleus(key, img_freqs)

        if cached_txt_key is not None and cached_txt_value is not None:
            txt_k, txt_v = cached_txt_key, cached_txt_value
        elif encoder_hidden_states is not None:
            txt_len = encoder_hidden_states.shape[1]
            txt_k = self.add_k_proj(encoder_hidden_states).view(batch, txt_len, self.num_kv_heads, -1)
            txt_v = self.add_v_proj(encoder_hidden_states).view(batch, txt_len, self.num_kv_heads, -1)
            txt_k = self.norm_added_k(txt_k)
            if image_rotary_emb is not None:
                txt_k = _apply_rotary_emb_nucleus(txt_k, txt_freqs)
        else:
            txt_k = txt_v = None

        if txt_k is None:
            joint_k, joint_v = key, value
        else:
            # Image keys first, text keys after -- the mask below relies on this order.
            joint_k = torch.cat([key, txt_k], dim=1)
            joint_v = torch.cat([value, txt_v], dim=1)

        if self.num_kv_groups > 1:
            joint_k = joint_k.repeat_interleave(self.num_kv_groups, dim=2)
            joint_v = joint_v.repeat_interleave(self.num_kv_groups, dim=2)

        # (B, S, H, D) -> (B, H, S, D), the layout passed with skip_reshape=True.
        query, joint_k, joint_v = (t.transpose(1, 2) for t in (query, joint_k, joint_v))

        attn_mask = None
        if attention_mask is not None:
            txt_span = attention_mask.shape[-1]
            attn_mask = hidden_states.new_zeros((batch, 1, txt_span + img_len))
            attn_mask[:, 0, img_len:] = attention_mask

        attended = optimized_attention_masked(
            query, joint_k, joint_v, self.heads,
            attn_mask, transformer_options=transformer_options,
            skip_reshape=True,
        )
        return self.to_out[0](attended)
|
|
|
|
|
|
class SwiGLUExperts(nn.Module):
    """Routed SwiGLU experts for the MoE layer.

    Weights are loaded in one of two layouts, chosen at state-dict load time:

    * packed -- one tensor with every expert's fused gate/up projection and one
      with every expert's down projection, indexed by expert first (per expert
      ``x @ gate_up[e]`` then ``@ down[e]``). They are stored in ``self.weight`` /
      ``self.bias``, may be quantized, and run through ``F.grouped_mm`` or a
      per-expert ``torch.matmul`` loop.
    * split -- per-expert ``gate_up_projs[i]`` / ``down_projs[i]`` Linear modules,
      built lazily when the state dict carries those keys (locally quantized files).
    """

    def __init__(
        self,
        hidden_size: int,
        moe_intermediate_dim: int,
        num_experts: int,
        use_grouped_mm: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.moe_intermediate_dim = moe_intermediate_dim
        self.hidden_size = hidden_size
        self.use_grouped_mm = use_grouped_mm
        # Flipped after the first RuntimeError from grouped_mm so later calls skip it.
        self._grouped_mm_failed = False
        # Kept for lazily building split experts and for quantized-tensor metadata.
        self._dtype = dtype
        self._device = device
        self._operations = operations

    def _has_packed_experts(self):
        """True once both packed tensors have been registered."""
        return not (getattr(self, "weight", None) is None or getattr(self, "bias", None) is None)

    def _register_packed_experts(self, gate_up, down):
        """Store the packed tensors in the weight/bias slots.

        NOTE(review): the slot names and the comfy_cast_weights / *_function attributes
        look chosen so that comfy.ops.cast_bias_weight (called in forward) can treat
        this module like a Linear layer.
        """
        for slot, tensor in (("weight", gate_up), ("bias", down)):
            setattr(self, slot, nn.Parameter(tensor, requires_grad=False))
        self.comfy_cast_weights = True
        self.weight_function, self.bias_function = [], []

    def _pop_quant_conf(self, state_dict, key):
        """Remove ``key`` from ``state_dict`` and decode it as a JSON byte tensor, or return None."""
        raw = state_dict.pop(key, None)
        if raw is not None:
            return json.loads(raw.cpu().numpy().tobytes())
        return None

    def _load_packed_quant_tensor(self, state_dict, key, tensor, quant_conf):
        """Wrap ``tensor`` in a QuantizedTensor according to ``quant_conf``.

        Returns ``tensor`` unchanged when ``quant_conf`` is None. The scale is taken
        from ``"<key>_scale"`` or, failing that, ``"<key>.weight_scale"`` (popped
        from ``state_dict``). Raises ValueError on a missing/unknown format or scale.
        """
        if quant_conf is None:
            return tensor

        fmt = quant_conf.get("format", None)
        if fmt is None:
            raise ValueError(f"Missing quantization format for Nucleus packed expert tensor {key}")
        if fmt not in QUANT_ALGOS:
            raise ValueError(f"Unsupported quantization format {fmt} for Nucleus packed expert tensor {key}")

        algo = QUANT_ALGOS[fmt]
        layout_name = algo["comfy_tensor_layout"]
        layout_cls = get_layout_class(layout_name)

        scale = state_dict.pop(f"{key}_scale", None)
        if scale is None:
            scale = state_dict.pop(f"{key}.weight_scale", None)
        if scale is None:
            raise ValueError(f"Missing quantization scale for Nucleus packed expert tensor {key}")

        params = layout_cls.Params(
            scale=scale,
            orig_dtype=self._dtype or torch.bfloat16,
            orig_shape=tuple(tensor.shape),
        )
        return QuantizedTensor(tensor.to(dtype=algo["storage_t"]), layout_name, params)

    def _dequantize_for_expert_mm(self, tensor, dtype):
        """Return a plain tensor of ``dtype`` (dequantizing QuantizedTensor inputs first)."""
        plain = tensor.dequantize() if isinstance(tensor, QuantizedTensor) else tensor
        return plain if plain.dtype == dtype else plain.to(dtype=dtype)

    def _build_split_experts(self):
        """Create the per-expert Linear modules once; later calls are no-ops."""
        if hasattr(self, "gate_up_projs"):
            return
        ops = self._operations
        factory = {"dtype": self._dtype, "device": self._device}
        self.gate_up_projs = nn.ModuleList(
            ops.Linear(self.hidden_size, 2 * self.moe_intermediate_dim, bias=False, **factory)
            for _ in range(self.num_experts)
        )
        self.down_projs = nn.ModuleList(
            ops.Linear(self.moe_intermediate_dim, self.hidden_size, bias=False, **factory)
            for _ in range(self.num_experts)
        )

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Detect the expert layout in ``state_dict`` and prepare matching parameters.

        Packed tensors are read from ``gate_up_proj``/``down_proj`` (preferred) or from
        ``weight``/``bias`` (the names _register_packed_experts uses), optionally with a
        ``comfy_quant`` JSON config. Otherwise, keys under ``gate_up_projs.`` /
        ``down_projs.`` trigger creation of the split Linear modules so the normal
        recursive loading can fill them.
        """
        hf_gate_up, hf_down = f"{prefix}gate_up_proj", f"{prefix}down_proj"
        comfy_gate_up, comfy_down = f"{prefix}weight", f"{prefix}bias"
        quant_key = f"{prefix}comfy_quant"

        hf_packed = hf_gate_up in state_dict and hf_down in state_dict
        comfy_packed = comfy_gate_up in state_dict and comfy_down in state_dict
        quant_conf = self._pop_quant_conf(state_dict, quant_key)
        split_prefixes = (f"{prefix}gate_up_projs.", f"{prefix}down_projs.")
        split = any(name.startswith(split_prefixes) for name in state_dict)

        packed = hf_packed or comfy_packed
        if packed:
            src_gate_up, src_down = (hf_gate_up, hf_down) if hf_packed else (comfy_gate_up, comfy_down)
            gate_up = state_dict.pop(src_gate_up)
            down = state_dict.pop(src_down)
            gate_up = self._load_packed_quant_tensor(state_dict, src_gate_up, gate_up, quant_conf)
            down = self._load_packed_quant_tensor(state_dict, src_down, down, quant_conf)
            self._register_packed_experts(gate_up, down)
        elif split:
            self._build_split_experts()

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        if packed:
            # The packed tensors were consumed above, so don't report them as missing.
            for name in (hf_gate_up, hf_down, comfy_gate_up, comfy_down, quant_key):
                if name in missing_keys:
                    missing_keys.remove(name)

    def _run_experts_split(self, x, num_tokens_per_expert):
        """Run contiguous token groups through the per-expert Linear modules.

        Rows beyond the counted tokens stay zero in the output.
        """
        result = torch.zeros_like(x)
        counts = num_tokens_per_expert.tolist()
        start = 0
        for expert_id in range(self.num_experts):
            count = int(counts[expert_id])
            if not count:
                continue
            stop = start + count
            gate, up = self.gate_up_projs[expert_id](x[start:stop]).chunk(2, dim=-1)
            result[start:stop] = self.down_projs[expert_id](F.silu(gate) * up)
            start = stop
        return result

    def _run_experts_packed_for_loop(self, x, num_tokens_per_expert, gate_up_proj, down_proj):
        """Per-expert ``torch.matmul`` over packed weights; trailing uncounted rows get zeros."""
        counts = num_tokens_per_expert.tolist()
        real = sum(counts)
        chunks = torch.split(x[:real], split_size_or_sections=counts, dim=0)

        pieces = []
        for expert_id, chunk in enumerate(chunks):
            if chunk.shape[0] == 0:
                continue
            gate, up = torch.matmul(chunk, gate_up_proj[expert_id]).chunk(2, dim=-1)
            pieces.append(torch.matmul(F.silu(gate) * up, down_proj[expert_id]))

        out = torch.cat(pieces, dim=0) if pieces else x.new_zeros((0, self.hidden_size))
        pad = x.shape[0] - real
        if pad > 0:
            out = torch.vstack((out, out.new_zeros((pad, out.shape[-1]))))
        return out

    def _can_use_grouped_mm(self, x, gate_up_proj):
        """Grouped GEMM is tried only when enabled, available, on CUDA, and with float dtypes."""
        allowed = (torch.float16, torch.bfloat16, torch.float32)
        if not self.use_grouped_mm or self._grouped_mm_failed:
            return False
        if not hasattr(F, "grouped_mm") or not x.is_cuda:
            return False
        return x.dtype in allowed and gate_up_proj.dtype in allowed

    def _run_experts_grouped_mm(self, x, num_tokens_per_expert, gate_up_proj, down_proj):
        """Run every expert in two ``F.grouped_mm`` calls using cumulative token offsets."""
        offs = num_tokens_per_expert.cumsum(dim=0, dtype=torch.int32)
        gate, up = F.grouped_mm(x, gate_up_proj, offs=offs).chunk(2, dim=-1)
        return F.grouped_mm(F.silu(gate) * up, down_proj, offs=offs).type_as(x)

    def forward(self, x, num_tokens_per_expert):
        """Run experts on tokens already grouped by expert.

        x: (total_tokens, hidden_size), expert 0's tokens first, then expert 1's, ...
        num_tokens_per_expert: (num_experts,) token counts per expert.
        Raises RuntimeError when neither layout has been loaded.
        """
        if not self._has_packed_experts():
            if hasattr(self, "gate_up_projs"):
                return self._run_experts_split(x, num_tokens_per_expert)
            raise RuntimeError("Nucleus MoE experts have not loaded packed or split weights.")

        gate_up_proj, down_proj, offload_stream = comfy.ops.cast_bias_weight(self, x, offloadable=True, compute_dtype=x.dtype)
        try:
            gate_up_proj = self._dequantize_for_expert_mm(gate_up_proj, x.dtype)
            down_proj = self._dequantize_for_expert_mm(down_proj, x.dtype)
            if self._can_use_grouped_mm(x, gate_up_proj):
                try:
                    return self._run_experts_grouped_mm(x, num_tokens_per_expert, gate_up_proj, down_proj)
                except RuntimeError:
                    # Fall back for this call and every later one.
                    self._grouped_mm_failed = True
            return self._run_experts_packed_for_loop(x, num_tokens_per_expert, gate_up_proj, down_proj)
        finally:
            comfy.ops.uncast_bias_weight(self, gate_up_proj, down_proj, offload_stream)
|
|
|
|
|
|
class NucleusMoELayer(nn.Module):
    """Expert-choice mixture-of-experts with an always-on shared expert.

    Each expert picks its ``capacity`` highest-scoring tokens per batch element
    (expert choice). A token's routed contribution is the sum, over the experts
    that picked it, of ``route_scale * normalised_score * expert(token)``, where
    the scores are renormalised to sum to 1 across those experts. Tokens picked by
    no expert receive only the shared-expert output.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        moe_intermediate_dim: int,
        capacity_factor: float = 2.0,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
        use_grouped_mm: bool = False,
        timestep_dim: int = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.use_sigmoid = use_sigmoid
        self.route_scale = route_scale

        if timestep_dim is None:
            timestep_dim = hidden_size

        # Router input is [timestep; hidden_states_unmodulated] concatenated on the feature axis.
        self.gate = operations.Linear(
            hidden_size + timestep_dim, num_experts, bias=False,
            dtype=dtype, device=device,
        )

        self.experts = SwiGLUExperts(
            hidden_size=hidden_size,
            moe_intermediate_dim=moe_intermediate_dim,
            num_experts=num_experts,
            use_grouped_mm=use_grouped_mm,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Shared expert sees every token; FeedForward layout keeps checkpoint keys compatible.
        self.shared_expert = FeedForward(hidden_size, dim_out=hidden_size, inner_dim=moe_intermediate_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, hidden_states, hidden_states_unmodulated, timestep=None):
        """Route tokens with expert-choice selection and combine with the shared expert.

        hidden_states: (B, S, D) tokens fed to the routed and shared experts.
        hidden_states_unmodulated: (B, S, D) tokens used, with ``timestep``, only for routing.
        timestep: (B, timestep_dim) conditioning, placed before the tokens in the router input.
            NOTE(review): without it the router input has only D features, which matches
            ``self.gate`` only if hidden_size + timestep_dim == D.
        Returns (B, S, D).
        """
        B, S, D = hidden_states.shape

        if timestep is not None:
            timestep_expanded = timestep.unsqueeze(1).expand(-1, S, -1)
            router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1)
        else:
            router_input = hidden_states_unmodulated

        logits = self.gate(router_input)  # (B, S, num_experts)
        if self.use_sigmoid:
            scores = torch.sigmoid(logits.float()).to(logits.dtype)
        else:
            scores = F.softmax(logits.float(), dim=-1).to(logits.dtype)

        # Tokens per expert per batch element. Clamped to S because torch.topk raises
        # when k exceeds the sequence length (capacity_factor > num_experts, or S == 0).
        capacity = min(S, max(1, math.ceil(self.capacity_factor * S / self.num_experts)))

        affinity = scores.transpose(1, 2)  # (B, num_experts, S)
        gating, top_indices = torch.topk(affinity, k=capacity, dim=-1)  # each (B, num_experts, capacity)

        # Flatten in expert-major order so each expert's tokens are contiguous,
        # matching the equal per-expert counts passed to self.experts.
        batch_offsets = torch.arange(B, device=hidden_states.device, dtype=torch.long).view(B, 1, 1) * S
        global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(-1)
        gating_flat = gating.transpose(0, 1).reshape(-1)

        # Renormalise each token's scores over the experts that selected it.
        token_score_sums = torch.zeros(B * S, device=hidden_states.device, dtype=gating_flat.dtype)
        token_score_sums.scatter_add_(0, global_token_indices, gating_flat)
        gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12)
        gating_flat = gating_flat * self.route_scale

        hidden_flat = hidden_states.reshape(-1, D)
        selected_hidden = hidden_flat[global_token_indices]  # (num_experts * B * capacity, D)

        num_tokens_per_expert = torch.full(
            (self.num_experts,),
            B * capacity,
            dtype=torch.long,
            device=hidden_states.device,
        )
        expert_output = self.experts(selected_hidden, num_tokens_per_expert)
        expert_output = (expert_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype)

        # Add routed contributions onto the shared-expert output for each selected token.
        output = self.shared_expert(hidden_states).reshape(B * S, D)
        scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, D)
        output = output.scatter_add(dim=0, index=scatter_idx, src=expert_output)
        return output.view(B, S, D)
|
|
|
|
|
|
class SwiGLUProj(nn.Module):
    """Fused input projection followed by a SwiGLU gate.

    ``proj`` maps ``dim -> 2 * inner_dim``; the first half is the value branch and
    the second half passes through SiLU as the gate. The submodule is named
    ``proj`` so checkpoint keys match.
    """

    def __init__(self, dim, inner_dim, bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim, 2 * inner_dim, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.silu(gate)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Dense SwiGLU MLP mapping ``dim -> inner_dim -> dim_out``.

    ``net`` is ``[SwiGLUProj, Identity, Linear]`` so checkpoint keys read
    ``net.0.proj.weight`` and ``net.2.weight``. Index 1 is a parameter-free
    dropout placeholder that ``forward`` skips.
    """

    def __init__(self, dim, dim_out=None, inner_dim=None, mult=4, dtype=None, device=None, operations=None):
        super().__init__()
        hidden = int(dim * mult) if inner_dim is None else inner_dim
        out_features = dim if dim_out is None else dim_out
        gated_in = SwiGLUProj(dim, hidden, bias=False, dtype=dtype, device=device, operations=operations)
        placeholder = nn.Identity()
        project_out = operations.Linear(hidden, out_features, bias=False, dtype=dtype, device=device)
        self.net = nn.ModuleList([gated_in, placeholder, project_out])

    def forward(self, hidden_states):
        gated_in, _placeholder, project_out = self.net
        return project_out(gated_in(hidden_states))
|
|
|
|
|
|
class NucleusMoEImageTransformerBlock(nn.Module):
    """Transformer block: modulated attention, then a modulated dense FFN or MoE.

    ``img_mod`` maps the conditioning vector ``temb`` to four per-channel tensors:
    two scales used as ``norm(x) * (1 + scale)`` and two gates used as residual
    multipliers ``tanh(clamp(gate, -2, 2))``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_kv_heads: int,
        joint_attention_dim: int = 4096,
        is_moe: bool = False,
        num_experts: int = 64,
        moe_intermediate_dim: int = 1344,
        capacity_factor: float = 2.0,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
        use_grouped_mm: bool = False,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.is_moe = is_moe
        factory = dict(dtype=dtype, device=device)

        # temb -> (scale1, gate1, scale2, gate2)
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 4 * dim, bias=True, **factory),
        )

        self.pre_attn_norm = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, **factory)
        self.pre_mlp_norm = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, **factory)

        self.attn = NucleusMoEAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            num_kv_heads=num_kv_heads,
            eps=eps,
            bias=False,
            operations=operations,
            **factory,
        )

        # Projects text features from joint_attention_dim to the block width.
        self.encoder_proj = operations.Linear(joint_attention_dim, dim, bias=True, **factory)

        if is_moe:
            self.img_mlp = NucleusMoELayer(
                hidden_size=dim,
                num_experts=num_experts,
                moe_intermediate_dim=moe_intermediate_dim,
                capacity_factor=capacity_factor,
                use_sigmoid=use_sigmoid,
                route_scale=route_scale,
                use_grouped_mm=use_grouped_mm,
                operations=operations,
                **factory,
            )
        else:
            # Dense blocks use 4x the per-expert width.
            self.img_mlp = FeedForward(
                dim=dim, dim_out=dim, inner_dim=4 * moe_intermediate_dim,
                operations=operations, **factory,
            )

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        image_rotary_emb=None,
        attention_kwargs=None,
        transformer_options={},
    ):
        """Apply one block to image tokens ``hidden_states`` conditioned on ``temb``.

        attention_kwargs may provide ``attention_mask``, ``cached_txt_key`` and
        ``cached_txt_value``. When ``cached_txt_key`` is set, the text projection
        is skipped and ``encoder_hidden_states`` is not used.
        float16 outputs are clipped to the finite float16 range.
        """
        modulation = self.img_mod(temb).unsqueeze(1)
        scale1, gate1, scale2, gate2 = modulation.chunk(4, dim=-1)
        gate1 = gate1.clamp(min=-2.0, max=2.0)
        gate2 = gate2.clamp(min=-2.0, max=2.0)

        opts = attention_kwargs or {}
        cached_key = opts.get("cached_txt_key")
        context = self.encoder_proj(encoder_hidden_states) if cached_key is None else None

        attn_out = self.attn(
            hidden_states=self.pre_attn_norm(hidden_states) * (1 + scale1),
            encoder_hidden_states=context,
            attention_mask=opts.get("attention_mask"),
            image_rotary_emb=image_rotary_emb,
            cached_txt_key=cached_key,
            cached_txt_value=opts.get("cached_txt_value"),
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + gate1.tanh() * attn_out

        normed = self.pre_mlp_norm(hidden_states)
        modulated = normed * (1 + scale2)
        if self.is_moe:
            # The MoE layer routes on the un-modulated tokens plus temb.
            mlp_out = self.img_mlp(modulated, normed, timestep=temb)
        else:
            mlp_out = self.img_mlp(modulated)
        hidden_states = hidden_states + gate2.tanh() * mlp_out

        if hidden_states.dtype == torch.float16:
            limits = torch.finfo(torch.float16)
            hidden_states = hidden_states.clip(limits.min, limits.max)
        return hidden_states
|
|
|
|
|
|
class AdaLayerNormContinuous(nn.Module):
    """LayerNorm with per-channel scale and shift predicted from a conditioning vector.

    ``out = norm(x) * (1 + scale) + shift`` where
    ``(scale, shift) = linear(silu(temb)).chunk(2, dim=1)``, broadcast over dim 1 of ``x``.
    """

    def __init__(self, embedding_dim, conditioning_embedding_dim, elementwise_affine=False, norm_eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(conditioning_embedding_dim, 2 * embedding_dim, bias=True, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, norm_eps, elementwise_affine, dtype=dtype, device=device)

    def forward(self, x, temb):
        conditioning = self.linear(self.silu(temb).to(x.dtype))
        scale, shift = torch.chunk(conditioning, 2, dim=1)
        scale = scale.unsqueeze(1)
        shift = shift.unsqueeze(1)
        return self.norm(x) * (1 + scale) + shift
|
|
|
|
|
|
class NucleusMoEImageTransformer2DModel(nn.Module):
|
|
def __init__(
    self,
    patch_size: int = 2,
    in_channels: int = 64,
    out_channels: int = 16,
    num_layers: int = 32,
    attention_head_dim: int = 128,
    num_attention_heads: int = 16,
    num_key_value_heads: int = 4,
    joint_attention_dim: int = 4096,
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    rope_theta: int = 10000,
    scale_rope: bool = True,
    dense_moe_strategy: str = "leave_first_three_blocks_dense",
    num_experts: int = 64,
    moe_intermediate_dim: int = 1344,
    capacity_factors: list = None,
    use_sigmoid: bool = False,
    route_scale: float = 2.5,
    use_grouped_mm: bool = False,
    dtype=None,
    device=None,
    operations=None,
    **kwargs,
):
    """Build the Nucleus-Image MoE diffusion transformer.

    Model width is ``num_attention_heads * attention_head_dim``. Which layers are
    MoE is taken from ``capacity_factors`` when given (an entry > 0 makes that
    layer MoE, and the entry is its capacity factor). Otherwise
    ``dense_moe_strategy`` decides, and MoE layers use capacity factor 2.0.
    Extra ``**kwargs`` are accepted and ignored.

    Raises:
        ValueError: if ``capacity_factors`` has fewer than ``num_layers`` entries.
    """
    super().__init__()
    self.dtype = dtype
    self.patch_size = patch_size
    self.in_channels = in_channels
    self.out_channels = out_channels or in_channels
    self.inner_dim = num_attention_heads * attention_head_dim

    # Per-layer MoE flags and capacity factors.
    if capacity_factors is not None:
        if len(capacity_factors) < num_layers:
            raise ValueError(
                f"capacity_factors has {len(capacity_factors)} entries, "
                f"expected at least num_layers={num_layers}"
            )
        moe_layers = [cf > 0 for cf in capacity_factors]
        layer_capacity = list(capacity_factors)
    else:
        moe_layers = [self._is_moe_layer(dense_moe_strategy, i, num_layers) for i in range(num_layers)]
        # Dense layers ignore the value; 0.0 only marks them.
        layer_capacity = [2.0 if is_moe else 0.0 for is_moe in moe_layers]
    self.moe_layers = moe_layers

    self.pos_embed = NucleusMoEEmbedRope(
        theta=rope_theta,
        axes_dim=list(axes_dims_rope),
        scale_rope=scale_rope,
        dtype=dtype,
        device=device,
        operations=operations,
    )

    self.time_text_embed = NucleusMoETimestepProjEmbeddings(
        embedding_dim=self.inner_dim,
        dtype=dtype,
        device=device,
        operations=operations,
    )

    # Input projections: text features are RMS-normalised; image patches are linearly embedded.
    self.txt_norm = operations.RMSNorm(joint_attention_dim, eps=1e-6, dtype=dtype, device=device)
    self.img_in = operations.Linear(in_channels * patch_size * patch_size, self.inner_dim, dtype=dtype, device=device)

    self.transformer_blocks = nn.ModuleList([
        NucleusMoEImageTransformerBlock(
            dim=self.inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_kv_heads=num_key_value_heads,
            joint_attention_dim=joint_attention_dim,
            is_moe=moe_layers[i],
            num_experts=num_experts,
            moe_intermediate_dim=moe_intermediate_dim,
            capacity_factor=layer_capacity[i],
            use_sigmoid=use_sigmoid,
            route_scale=route_scale,
            use_grouped_mm=use_grouped_mm,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        for i in range(num_layers)
    ])

    self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, norm_eps=1e-6, dtype=dtype, device=device, operations=operations)
    self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False, dtype=dtype, device=device)
|
|
|
|
@staticmethod
|
|
def _is_moe_layer(strategy, layer_idx, num_layers):
|
|
if strategy == "leave_first_three_and_last_block_dense":
|
|
return layer_idx >= 3 and layer_idx < num_layers - 1
|
|
elif strategy == "leave_first_three_blocks_dense":
|
|
return layer_idx >= 3
|
|
elif strategy == "leave_first_block_dense":
|
|
return layer_idx >= 1
|
|
elif strategy == "all_moe":
|
|
return True
|
|
elif strategy == "all_dense":
|
|
return False
|
|
return True
|
|
|
|
@staticmethod
def _normalize_attention_mask(attention_mask, dtype):
    """Convert a text attention mask into an additive bias of ``dtype``.

    Behaviour:
      * ``None`` is passed through unchanged.
      * A mask with more than two dimensions is flattened to
        ``(attention_mask.shape[0], -1)``.
      * Non-floating-point masks (e.g. bool/int), and floating masks whose
        entries are all exactly 0 or 1, are treated as binary masks and mapped
        with ``(mask - 1) * finfo(dtype).max``: 1 -> 0, 0 -> ``-finfo(dtype).max``.
      * Any other floating mask is assumed to already be additive and is only
        cast to ``dtype``.

    NOTE(review): an already-additive float mask that is all zeros (nothing
    masked) cannot be told apart from a binary mask with every position set to
    0, and will be converted to all ``-finfo(dtype).max``.
    """
    if attention_mask is None:
        return None

    if attention_mask.ndim > 2:
        attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)

    # `or` short-circuits, so the elementwise 0/1 scan only runs for float masks.
    binary = not torch.is_floating_point(attention_mask) or bool(
        torch.all((attention_mask == 0) | (attention_mask == 1))
    )
    converted = attention_mask.to(dtype)
    if binary:
        return (converted - 1) * torch.finfo(dtype).max
    return converted
|
|
|
|
def process_img(self, x, index=0, h_offset=0, w_offset=0):
    """Patchify a latent tensor into a sequence of image tokens.

    ``x`` is unpacked as ``(batch, channels, frames, height, width)``. It is
    padded with ``pad_to_patch_size`` using patch sizes ``(1, p, p)`` with
    ``p = self.patch_size`` (the view below requires H and W to be divisible
    by ``p``). Each ``p x p`` spatial patch is then flattened into one token of
    ``channels * p * p`` features, which matches the input width ``img_in`` is
    constructed with.

    Previously the view/reshape hard-coded a patch size of 2 regardless of
    ``self.patch_size``. The token grid was also computed by rounding rather
    than from the padded shape. Both agree with this implementation when
    ``patch_size == 2``.

    Args:
        x: 5-D latent tensor ``(B, C, T, H, W)``.
        index, h_offset, w_offset: accepted for signature compatibility; not
            used by this model.

    Returns:
        ``(tokens, (t_len, h_len, w_len), padded_shape)``, where:
          * ``tokens`` has shape ``(B, t_len * h_len * w_len, C * p * p)``;
          * ``(t_len, h_len, w_len)`` is the token grid;
          * ``padded_shape`` is the shape of ``x`` after padding, which the
            caller needs to undo the patchification.
    """
    p = self.patch_size
    padded = comfy.ldm.common_dit.pad_to_patch_size(x, (1, p, p))
    orig_shape = padded.shape
    batch, channels, t_len, height, width = orig_shape
    # Take the grid from the padded tensor so it always equals the number of
    # tokens produced below, for any patch size.
    h_len = height // p
    w_len = width // p

    hidden_states = padded.view(batch, channels, t_len, h_len, p, w_len, p)
    # -> (B, T, H/p, W/p, C, p, p): one contiguous C*p*p feature row per patch.
    hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
    hidden_states = hidden_states.reshape(batch, t_len * h_len * w_len, channels * p * p)
    return hidden_states, (t_len, h_len, w_len), orig_shape
|
|
|
|
def forward(self, x, timestep, context, attention_mask=None, transformer_options=None, **kwargs):
    """Run the model, composing any DIFFUSION_MODEL wrappers around ``_forward``.

    Wrappers are looked up in ``transformer_options`` under
    ``WrappersMP.DIFFUSION_MODEL``. All arguments are forwarded to
    ``_forward`` in the same order.

    Args:
        x: latent input.
        timestep: diffusion timestep(s).
        context: text conditioning (used as ``encoder_hidden_states``).
        attention_mask: optional text attention mask.
        transformer_options: per-call options dict. ``_forward`` writes keys
            into it, so when none is given a fresh dict is created rather than
            sharing one mutable default object across calls.
        **kwargs: forwarded unchanged to ``_forward``.
    """
    if transformer_options is None:
        transformer_options = {}
    wrappers = comfy.patcher_extension.get_all_wrappers(
        comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
    )
    executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(self._forward, self, wrappers)
    return executor.execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
|
|
|
|
def _forward(
    self,
    x,
    timesteps,
    context,
    attention_mask=None,
    transformer_options=None,
    control=None,
    **kwargs,
):
    """Transformer forward pass: patchify, run every block, unpatchify.

    Args:
        x: latent tensor ``(B, C, T, H, W)`` (unpacked by ``process_img``).
        timesteps: timestep(s) passed to ``time_text_embed``.
        context: text conditioning; normalised by ``txt_norm`` and given to
            every block as ``encoder_hidden_states``.
        attention_mask: optional text mask. It is converted by
            ``_normalize_attention_mask`` and passed to the blocks as
            ``attention_kwargs["attention_mask"]``.
        transformer_options: per-call options dict (a fresh one is used when
            ``None``):
              * read: ``patches`` (``"post_input"``, ``"single_block"``) and
                ``patches_replace["dit"][("single_block", i)]``;
              * written: ``total_blocks``, ``block_type`` and
                ``block_index``.
        control: optional dict. Its ``"input"`` entry, if present, is a list
            of per-block tensors. Entry ``i`` (when not ``None``) is added onto
            the leading image tokens after block ``i``.
        **kwargs: accepted and ignored.

    Returns:
        The negated prediction, cropped back to ``x.shape``.
    """
    if transformer_options is None:
        transformer_options = {}

    encoder_hidden_states = context
    encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype)

    block_attention_kwargs = {}
    if encoder_hidden_states_mask is not None:
        block_attention_kwargs["attention_mask"] = encoder_hidden_states_mask

    hidden_states, (t_len, h_len, w_len), orig_shape = self.process_img(x)
    num_embeds = hidden_states.shape[1]

    # Compute RoPE
    img_freqs, txt_freqs = self.pos_embed(
        video_fhw=[(t_len, h_len, w_len)],
        device=x.device,
        max_txt_seq_len=encoder_hidden_states.shape[1],
    )
    image_rotary_emb = (img_freqs, txt_freqs)

    # Project inputs
    hidden_states = self.img_in(hidden_states)
    encoder_hidden_states = self.txt_norm(encoder_hidden_states)

    temb = self.time_text_embed(timesteps, hidden_states)

    patches_replace = transformer_options.get("patches_replace", {})
    patches = transformer_options.get("patches", {})
    blocks_replace = patches_replace.get("dit", {})

    if "post_input" in patches:
        for patch in patches["post_input"]:
            out = patch({
                "img": hidden_states,
                "txt": encoder_hidden_states,
                "transformer_options": transformer_options,
            })
            hidden_states = out["img"]
            encoder_hidden_states = out["txt"]

    transformer_options["total_blocks"] = len(self.transformer_blocks)
    transformer_options["block_type"] = "single"

    for i, block in enumerate(self.transformer_blocks):
        transformer_options["block_index"] = i

        if ("single_block", i) in blocks_replace:
            # `block` is bound as a default argument so a replacement patch
            # that keeps `original_block` and calls it later still runs this
            # iteration's block rather than whichever the loop reached last.
            def block_wrap(args, block=block):
                out = block(
                    hidden_states=args["img"],
                    encoder_hidden_states=args["txt"],
                    temb=args["vec"],
                    image_rotary_emb=args["pe"],
                    attention_kwargs=block_attention_kwargs,
                    transformer_options=args["transformer_options"],
                )
                return {"img": out, "txt": args["txt"]}

            out = blocks_replace[("single_block", i)](
                {"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options},
                {"original_block": block_wrap},
            )
            hidden_states = out["img"]
            encoder_hidden_states = out.get("txt", encoder_hidden_states)
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=block_attention_kwargs,
                transformer_options=transformer_options,
            )

        if "single_block" in patches:
            for patch in patches["single_block"]:
                out = patch({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "x": x,
                    "block_index": i,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out.get("txt", encoder_hidden_states)

        if control is not None:
            control_i = control.get("input")
            # A control dict without an "input" entry previously crashed on len(None).
            if control_i is not None and i < len(control_i):
                add = control_i[i]
                if add is not None:
                    hidden_states[:, :add.shape[1]] += add

    # Output
    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # Unpack tokens back to (B, C, T, H, W): the inverse of process_img's
    # patchify, using the same patch size (previously hard-coded as 2).
    ps = self.patch_size
    batch, channels, frames, height, width = orig_shape
    hidden_states = hidden_states[:, :num_embeds].view(
        batch, frames,
        height // ps, width // ps,
        channels, ps, ps,
    )
    hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
    # Drop the padding process_img added so the output matches x spatially.
    hidden_states = hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
    # Diffusers negates Nucleus predictions before FlowMatchEulerDiscreteScheduler.step().
    return -hidden_states
|