mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
Initial Microsoft Lens support
This commit is contained in:
parent
d80fcafee7
commit
5ecaf09544
578
comfy/ldm/lens/model.py
Normal file
578
comfy/ldm/lens/model.py
Normal file
@ -0,0 +1,578 @@
|
||||
"""Lens denoising transformer (DiT)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||
|
||||
|
||||
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class LensEmbedRope(nn.Module):
|
||||
"""Frame/H/W axial RoPE shared between image and text streams."""
|
||||
|
||||
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = list(axes_dim)
|
||||
self.scale_rope = scale_rope
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
# Plain attributes (not buffers): register_buffer strips complex imag.
|
||||
self.pos_freqs = torch.cat(
|
||||
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
|
||||
)
|
||||
self.rope_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
@staticmethod
|
||||
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
|
||||
)
|
||||
return torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
|
||||
txt_seq_lens: Union[List[int], int],
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
if not isinstance(txt_seq_lens, list):
|
||||
txt_seq_lens = [txt_seq_lens]
|
||||
assert len(video_fhw) == 1, "video_fhw must have length 1"
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}_{device}"
|
||||
video_freq = self.rope_cache.get(rope_key)
|
||||
if video_freq is None:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
|
||||
self.rope_cache[rope_key] = video_freq
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
return torch.cat(vid_freqs, dim=0), txt_freqs
|
||||
|
||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
).view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat(
|
||||
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
|
||||
).view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class _TimestepEmbedder(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = F.silu(x)
|
||||
return self.linear_2(x)
|
||||
|
||||
|
||||
class LensTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
proj = _lens_time_proj(timestep, 256)
|
||||
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
|
||||
class GateMLP(nn.Module):
|
||||
"""SwiGLU MLP."""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
|
||||
|
||||
|
||||
class LensJointAttention(nn.Module):
|
||||
"""Joint image+text attention with fused QKV per stream."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
added_kv_proj_dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
out_dim: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.heads = self.inner_dim // dim_head
|
||||
self.dim_head = dim_head
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# ModuleList([Linear, Identity]) for state-dict key compatibility.
|
||||
self.to_out = nn.ModuleList([
|
||||
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.Identity(),
|
||||
])
|
||||
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bsz, seq_img, _ = hidden_states.shape
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# image stream
|
||||
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
|
||||
img_q, img_k, img_v = img_qkv.unbind(dim=2)
|
||||
img_q = self.norm_q(img_q)
|
||||
img_k = self.norm_k(img_k)
|
||||
img_v = img_v.contiguous()
|
||||
del img_qkv
|
||||
|
||||
# text stream
|
||||
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
|
||||
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
|
||||
txt_q = self.norm_added_q(txt_q)
|
||||
txt_k = self.norm_added_k(txt_k)
|
||||
txt_v = txt_v.contiguous()
|
||||
del txt_qkv
|
||||
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
if img_freqs.shape[0] < seq_img:
|
||||
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
|
||||
img_freqs = img_freqs[:seq_img]
|
||||
img_q = apply_rotary_emb_lens(img_q, img_freqs)
|
||||
img_k = apply_rotary_emb_lens(img_k, img_freqs)
|
||||
if seq_txt > 0:
|
||||
if txt_freqs.shape[0] < seq_txt:
|
||||
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
|
||||
txt_freqs = txt_freqs[:seq_txt]
|
||||
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
|
||||
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
|
||||
|
||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||
del img_v, txt_v
|
||||
|
||||
if attention_mask is not None:
|
||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||
if attention_mask.shape != expected:
|
||||
raise ValueError(
|
||||
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
|
||||
)
|
||||
attention_mask = attention_mask.to(q.dtype)
|
||||
|
||||
out = optimized_attention(
|
||||
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
|
||||
txt_out = self.to_add_out(out[:, seq_img:, :])
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class LensTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
rms_norm: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = LensJointAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if rms_norm:
|
||||
NormCls = operations.RMSNorm
|
||||
norm_kwargs = {}
|
||||
else:
|
||||
NormCls = operations.LayerNorm
|
||||
norm_kwargs = {"elementwise_affine": False}
|
||||
|
||||
mlp_hidden = int(dim / 3 * 8)
|
||||
|
||||
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
|
||||
img_attn, txt_attn = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class _AdaLayerNormContinuousNoAffine(nn.Module):
|
||||
"""AdaLayerNormContinuous(elementwise_affine=False).
|
||||
|
||||
The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite
|
||||
to Flux's ``LastLayer``.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
|
||||
dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(
|
||||
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.eps = eps
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(F.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=-1)
|
||||
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class LensTransformer2DModel(nn.Module):
|
||||
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_layers: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
enc_hidden_dim: int = 2880,
|
||||
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
|
||||
rms_norm: bool = True,
|
||||
multi_layer_encoder_feature: bool = True,
|
||||
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
|
||||
image_model=None, # unused; accepted for detection-side configs.
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.multi_layer_encoder_feature = multi_layer_encoder_feature
|
||||
self.selected_layer_index = list(selected_layer_index)
|
||||
self.dtype = dtype
|
||||
|
||||
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
self.txt_norm = nn.ModuleList(
|
||||
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
for _ in self.selected_layer_index]
|
||||
)
|
||||
self.txt_in = operations.Linear(
|
||||
enc_hidden_dim * len(self.selected_layer_index),
|
||||
self.inner_dim, bias=True, dtype=dtype, device=device,
|
||||
)
|
||||
else:
|
||||
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
LensTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
eps=1e-6,
|
||||
rms_norm=rms_norm,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = _AdaLayerNormContinuousNoAffine(
|
||||
self.inner_dim, self.inner_dim, eps=1e-6,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.proj_out = operations.Linear(
|
||||
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
|
||||
dtype=dtype, device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
control: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
B, C, h, w = x.shape
|
||||
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
L = len(self.selected_layer_index)
|
||||
enc_dim = context.shape[-1] // L
|
||||
encoder_hidden_states = list(
|
||||
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
|
||||
)
|
||||
text_seq_len = encoder_hidden_states[0].shape[1]
|
||||
else:
|
||||
encoder_hidden_states = context
|
||||
text_seq_len = context.shape[1]
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(B, text_seq_len), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
img_len = h * w
|
||||
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
|
||||
encoder_hidden_states = torch.cat(normed, dim=-1)
|
||||
else:
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
attention_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"vec": temb,
|
||||
"pe": image_rotary_emb,
|
||||
"attn_mask": joint_mask,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
encoder_hidden_states = out["txt"]
|
||||
hidden_states = out["img"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=joint_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"x": x,
|
||||
"block_index": i,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None:
|
||||
control_i = control.get("input")
|
||||
if control_i is not None and i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
out = self.proj_out(hidden_states)
|
||||
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
@staticmethod
|
||||
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
|
||||
if text_mask.dtype != torch.bool:
|
||||
text_mask = text_mask.bool()
|
||||
bsz = text_mask.shape[0]
|
||||
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
|
||||
joint = torch.cat([img_ones, text_mask], dim=1)
|
||||
additive = torch.zeros_like(joint, dtype=torch.float32)
|
||||
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
|
||||
return additive[:, None, None, :]
|
||||
@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lens.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
@ -1058,6 +1059,27 @@ class Flux2(Flux):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class Lens(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(
|
||||
model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
|
||||
)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None # Lens has no pooled/ADM conditioning.
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
|
||||
@ -755,6 +755,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
|
||||
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
|
||||
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
|
||||
if multi_layer:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
|
||||
# Indices are TE-side; the DiT just consumes L layers in order.
|
||||
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
|
||||
else:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
|
||||
selected_layer_index = (0,)
|
||||
|
||||
return {
|
||||
"image_model": "lens",
|
||||
"in_channels": img_in_w.shape[1],
|
||||
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
|
||||
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
|
||||
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
|
||||
"enc_hidden_dim": enc_hidden_dim,
|
||||
"multi_layer_encoder_feature": multi_layer,
|
||||
"selected_layer_index": selected_layer_index,
|
||||
}
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
|
||||
14
comfy/sd.py
14
comfy/sd.py
@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1269,6 +1270,7 @@ class CLIPType(Enum):
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
|
||||
|
||||
|
||||
@ -1321,6 +1323,7 @@ class TEModel(Enum):
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1362,6 +1365,12 @@ def detect_te_model(sd):
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||
if "model.layers.0.self_attn.sinks" in sd and (
|
||||
"model.layers.0.mlp.experts.gate_up_proj" in sd
|
||||
or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd
|
||||
):
|
||||
return TEModel.GPT_OSS_20B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
@ -1544,6 +1553,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.GPT_OSS_20B:
|
||||
mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data)
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4)
|
||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
|
||||
|
||||
@ -829,6 +829,50 @@ class Flux2(Flux):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Lens(supported_models_base.BASE):
|
||||
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
|
||||
|
||||
unet_config = {
|
||||
"image_model": "lens",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
for hint in ("gpt_oss.transformer.", ""):
|
||||
full_prefix = "{}{}".format(pref, hint)
|
||||
if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||
if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict:
|
||||
detect["mxfp4_runtime"] = True
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||
)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(),
|
||||
)
|
||||
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -2096,6 +2140,7 @@ models = [
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
|
||||
789
comfy/text_encoders/gpt_oss.py
Normal file
789
comfy/text_encoders/gpt_oss.py
Normal file
@ -0,0 +1,789 @@
|
||||
"""GPT-OSS text encoder for Lens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class GptOss20BConfig:
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 64
|
||||
num_local_experts: int = 32
|
||||
num_experts_per_tok: int = 4
|
||||
sliding_window: int = 128
|
||||
original_max_position_embeddings: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_factor: float = 32.0
|
||||
rope_beta_fast: float = 32.0
|
||||
rope_beta_slow: float = 1.0
|
||||
rope_truncate: bool = False
|
||||
rms_norm_eps: float = 1e-5
|
||||
attention_bias: bool = True
|
||||
layer_types: Optional[List[str]] = None
|
||||
moe_alpha: float = 1.702
|
||||
moe_limit: float = 7.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if (i + 1) % 2 else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
def _yarn_inv_freq(
|
||||
head_dim: int,
|
||||
base: float,
|
||||
factor: float,
|
||||
beta_fast: float,
|
||||
beta_slow: float,
|
||||
original_max_position_embeddings: int,
|
||||
truncate: bool,
|
||||
device=None,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||||
dim = head_dim
|
||||
|
||||
def find_correction_dim(num_rotations: float) -> float:
|
||||
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def find_correction_range() -> tuple[float, float]:
|
||||
low = find_correction_dim(beta_fast)
|
||||
high = find_correction_dim(beta_slow)
|
||||
if truncate:
|
||||
low = math.floor(low)
|
||||
high = math.ceil(high)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
|
||||
if min_ == max_:
|
||||
max_ += 0.001
|
||||
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
|
||||
return torch.clamp(linear, 0, 1)
|
||||
|
||||
def get_mscale(scale: float) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
attention_scaling = get_mscale(factor)
|
||||
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range()
|
||||
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
|
||||
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
|
||||
return inv_freq, attention_scaling
|
||||
|
||||
|
||||
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
pos_e = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||||
|
||||
|
||||
def _attention_with_sinks(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
sinks: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
num_heads: int,
|
||||
num_kv_groups: int,
|
||||
) -> torch.Tensor:
|
||||
"""Attention with per-head sinks.
|
||||
|
||||
Sinks add a learned term to each row's softmax denominator but contribute
|
||||
nothing to the output. We fake this by appending one zero k/v position and
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
|
||||
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
|
||||
if attention_mask is not None:
|
||||
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
|
||||
else:
|
||||
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
|
||||
mask = torch.cat([mask_left, sinks_col], dim=-1)
|
||||
|
||||
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
|
||||
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
||||
|
||||
bias = config.attention_bias
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
|
||||
B, S, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
# Mixture of Experts
|
||||
|
||||
class GptOssTopKRouter(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.num_local_experts
|
||||
# Raw Parameters (not Linear) to match HF state-dict keys.
|
||||
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
logits = F.linear(hidden_states, self.weight, self.bias)
|
||||
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
||||
# Softmax over top-k slice only (matches transformers), not all experts.
|
||||
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
||||
return scores, top_idx
|
||||
|
||||
|
||||
class GptOssExperts(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.alpha = config.moe_alpha
|
||||
self.limit = config.moe_limit
|
||||
|
||||
E = self.num_experts
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
|
||||
self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype))
|
||||
self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype))
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype))
|
||||
self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype))
|
||||
|
||||
def switch_to_mxfp4(self, device=None):
|
||||
"""Swap bf16 weight Parameters for uint8 MXFP4 packed buffers.
|
||||
|
||||
On-disk MXFP4 layout: ``[E, 2*I, G_up, 16]`` uint8 + ``[E, 2*I, G_up]``
|
||||
uint8 (E8M0) for ``gate_up``; ``[E, H, G_down, 16]`` + ``[E, H, G_down]``
|
||||
for ``down``. ``G_up * 32 = H``, ``G_down * 32 = I``.
|
||||
"""
|
||||
E, H, I = self.num_experts, self.hidden_size, self.intermediate_size
|
||||
if H % 32 != 0 or I % 32 != 0:
|
||||
raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={H}, I={I}")
|
||||
del self.gate_up_proj
|
||||
del self.down_proj
|
||||
G_up = H // 32
|
||||
G_down = I // 32
|
||||
self.register_buffer("gate_up_proj_blocks", torch.empty(E, 2 * I, G_up, 16, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("gate_up_proj_scales", torch.empty(E, 2 * I, G_up, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("down_proj_blocks", torch.empty(E, H, G_down, 16, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("down_proj_scales", torch.empty(E, H, G_down, dtype=torch.uint8, device=device))
|
||||
|
||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||
gate = gate_up[..., ::2]
|
||||
up = gate_up[..., 1::2]
|
||||
gate = gate.clamp(max=self.limit)
|
||||
up = up.clamp(min=-self.limit, max=self.limit)
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
return (up + 1) * glu
|
||||
|
||||
@staticmethod
|
||||
def _dequant_one_expert(
|
||||
blocks_e: torch.Tensor,
|
||||
scales_e: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
"""Dequant one expert's MXFP4 ``[D, G, 16]`` + ``[D, G]`` to ``[G*32, D]``."""
|
||||
D, G, B = blocks_e.shape
|
||||
val_per_row = G * 32
|
||||
lut = _fp4_lut(dtype, blocks_e.device)
|
||||
blocks_flat = blocks_e.reshape(D * G, B)
|
||||
scales_flat = (scales_e.to(torch.int32) - 127).reshape(D * G, 1)
|
||||
dec = torch.empty(D * G, B * 2, dtype=dtype, device=blocks_e.device)
|
||||
dec[:, 0::2] = lut[(blocks_flat & 0x0F).to(torch.long)]
|
||||
dec[:, 1::2] = lut[(blocks_flat >> 4).to(torch.long)]
|
||||
torch.ldexp(dec, scales_flat, out=dec)
|
||||
return dec.view(D, val_per_row).transpose(0, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||
N = hidden_states.shape[0]
|
||||
top_k = router_indices.shape[-1]
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
is_mxfp4 = hasattr(self, "gate_up_proj_blocks")
|
||||
|
||||
for ei in expert_hit:
|
||||
expert_idx = int(ei.item())
|
||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current = hidden_states[token_idx]
|
||||
|
||||
if is_mxfp4:
|
||||
gate_up_w = self._dequant_one_expert(
|
||||
self.gate_up_proj_blocks[expert_idx],
|
||||
self.gate_up_proj_scales[expert_idx],
|
||||
current.dtype,
|
||||
)
|
||||
down_w = self._dequant_one_expert(
|
||||
self.down_proj_blocks[expert_idx],
|
||||
self.down_proj_scales[expert_idx],
|
||||
current.dtype,
|
||||
)
|
||||
else:
|
||||
gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False)
|
||||
down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False)
|
||||
|
||||
gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False)
|
||||
down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False)
|
||||
|
||||
gate_up = current @ gate_up_w + gate_up_b
|
||||
gated = self._apply_gate(gate_up)
|
||||
expert_out = gated @ down_w + down_b
|
||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||
|
||||
flat_idx = token_idx * top_k + top_k_pos
|
||||
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
||||
|
||||
return per_pair.view(N, top_k, H).sum(dim=1)
|
||||
|
||||
|
||||
class GptOssMLP(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
B, S, H = hidden_states.shape
|
||||
flat = hidden_states.reshape(-1, H)
|
||||
scores, idx = self.router(flat)
|
||||
out = self.experts(flat, idx, scores)
|
||||
return out.reshape(B, S, H)
|
||||
|
||||
|
||||
# Decoder layer + model
|
||||
|
||||
class GptOssDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
i = torch.arange(S, device=device).view(-1, 1)
|
||||
j = torch.arange(S, device=device).view(1, -1)
|
||||
keep = (j <= i) & (j > i - window)
|
||||
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
class GptOssModel(nn.Module):
|
||||
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
|
||||
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
# Always build on CPU so the buffer survives meta-device construction.
|
||||
inv_freq, attn_scaling = _yarn_inv_freq(
|
||||
head_dim=config.head_dim,
|
||||
base=config.rope_theta,
|
||||
factor=config.rope_factor,
|
||||
beta_fast=config.rope_beta_fast,
|
||||
beta_slow=config.rope_beta_slow,
|
||||
original_max_position_embeddings=config.original_max_position_embeddings,
|
||||
truncate=config.rope_truncate,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
|
||||
self.rope_attention_scaling = float(attn_scaling)
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
|
||||
masks = {"full_attention": full}
|
||||
if any(t == "sliding_attention" for t in self.config.layer_types):
|
||||
masks["sliding_attention"] = _make_sliding_causal_mask(
|
||||
B, S, self.config.sliding_window, attention_mask, dtype, device
|
||||
)
|
||||
return masks
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
|
||||
B, S = input_ids.shape
|
||||
device = input_ids.device
|
||||
dtype = self.dtype
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
|
||||
|
||||
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
|
||||
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
|
||||
|
||||
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
|
||||
|
||||
capture_layers = list(capture_layers) if capture_layers else None
|
||||
if capture_layers:
|
||||
max_layer = max(capture_layers)
|
||||
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
|
||||
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
|
||||
else:
|
||||
max_layer = self.config.num_hidden_layers - 1
|
||||
wanted = None
|
||||
captured = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
|
||||
if wanted is not None and i in wanted:
|
||||
captured[wanted[i]] = hidden_states
|
||||
if i >= max_layer:
|
||||
break
|
||||
|
||||
if captured is not None:
|
||||
return {"hidden_states": captured}
|
||||
return {"last_hidden_state": self.norm(hidden_states)}
|
||||
|
||||
|
||||
# Lens chat-template constants (verbatim from the reference pipeline).
|
||||
_LENS_CHAT_SYSTEM = (
|
||||
"Describe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background."
|
||||
)
|
||||
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
|
||||
LENS_TXT_OFFSET = 97
|
||||
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
|
||||
LENS_MAX_TOKENS = 512
|
||||
|
||||
|
||||
# The reference GPT-OSS Harmony template injects today's date here
|
||||
_LENS_CHAT_DATE = "2026-05-23"
|
||||
|
||||
|
||||
def _lens_render_chat(prompt: str) -> str:
|
||||
"""Render the Lens prompt in GPT-OSS Harmony format."""
|
||||
return (
|
||||
f"<|start|>system<|message|>"
|
||||
f"You are ChatGPT, a large language model trained by OpenAI.\n"
|
||||
f"Knowledge cutoff: 2024-06\n"
|
||||
f"Current date: {_LENS_CHAT_DATE}\n\n"
|
||||
f"Reasoning: medium\n\n"
|
||||
f"# Valid channels: analysis, commentary, final. "
|
||||
f"Channel must be included for every message.<|end|>"
|
||||
f"<|start|>developer<|message|># Instructions\n\n"
|
||||
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
|
||||
f"<|start|>user<|message|>{prompt}<|end|>"
|
||||
f"<|start|>assistant<|channel|>analysis<|message|>"
|
||||
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>"
|
||||
)
|
||||
|
||||
|
||||
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
|
||||
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
|
||||
|
||||
|
||||
class _GptOssRawTokenizer:
|
||||
"""Raw ``tokenizers.Tokenizer`` wrapper.
|
||||
|
||||
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
|
||||
(``tokenizer_json`` key) rather than as a committed file. Extracted
|
||||
it in ``sd.py`` and passes it here via ``tokenizer_data``.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||
from tokenizers import Tokenizer
|
||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||
if tokenizer_json_bytes is None:
|
||||
raise ValueError(
|
||||
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
|
||||
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
|
||||
"embeds the tokenizer."
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, tokenizer_data, **kwargs):
|
||||
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
|
||||
|
||||
def __call__(self, text):
|
||||
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.tokenizer.token_to_id(t) for t in tokens]
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
|
||||
|
||||
|
||||
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
|
||||
tokenizer_json_data = None
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
|
||||
self.tokenizer_json_data = tokenizer_json
|
||||
super().__init__(
|
||||
tokenizer_json,
|
||||
embedding_directory=embedding_directory,
|
||||
pad_with_end=False,
|
||||
embedding_size=2880,
|
||||
embedding_key="gpt_oss",
|
||||
tokenizer_class=_GptOssRawTokenizer,
|
||||
has_start_token=False,
|
||||
has_end_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=1,
|
||||
pad_left=False,
|
||||
disable_weights=True,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
self.pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
|
||||
if not text or not text.strip():
|
||||
return [[]]
|
||||
rendered = _lens_render_chat(text)
|
||||
ids = self.tokenizer(rendered)["input_ids"]
|
||||
if len(ids) > LENS_MAX_TOKENS:
|
||||
ids = ids[:LENS_MAX_TOKENS]
|
||||
return [[(int(t), 1.0) for t in ids]]
|
||||
|
||||
def state_dict(self):
|
||||
if self.tokenizer_json_data is not None:
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
|
||||
class LensTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
tokenizer_data=tokenizer_data,
|
||||
name="gpt_oss",
|
||||
tokenizer=LensGptOssTokenizer,
|
||||
)
|
||||
|
||||
|
||||
# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa).
|
||||
_FP4_VALUES: Tuple[float, ...] = (
|
||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
||||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
||||
)
|
||||
|
||||
|
||||
_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {}
|
||||
|
||||
|
||||
def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||
"""Cached per (dtype, device) FP4 lookup table — avoids per-call allocation."""
|
||||
key = (dtype, str(device))
|
||||
lut = _FP4_LUT_CACHE.get(key)
|
||||
if lut is None:
|
||||
lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device)
|
||||
_FP4_LUT_CACHE[key] = lut
|
||||
return lut
|
||||
|
||||
|
||||
def _safe_dequant_moe_tensor(
|
||||
blocks: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
*,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
rows_per_chunk: int = 4096,
|
||||
) -> torch.Tensor:
|
||||
"""Eager full-tensor MXFP4 dequant -> ``[E, H, 2*I]`` ``dtype``.
|
||||
|
||||
Allocates the output in its final transposed layout and writes in chunks
|
||||
"""
|
||||
blocks = blocks.to(torch.uint8)
|
||||
scales = scales.to(torch.int32) - 127
|
||||
|
||||
assert blocks.shape[:-1] == scales.shape, (
|
||||
f"{blocks.shape[:-1]=} does not match {scales.shape=}"
|
||||
)
|
||||
*prefix_shape, G, B = blocks.shape
|
||||
if len(prefix_shape) != 2:
|
||||
raise ValueError(f"expected 2-D prefix (E, 2*I); got {prefix_shape}")
|
||||
|
||||
E, D = prefix_shape
|
||||
val_per_row = G * B * 2 # this is H after dequant
|
||||
|
||||
rows_total = E * D * G
|
||||
blocks = blocks.reshape(rows_total, B)
|
||||
scales = scales.reshape(rows_total, 1)
|
||||
|
||||
lut = _fp4_lut(dtype, blocks.device)
|
||||
out = torch.empty(E, val_per_row, D, dtype=dtype, device=blocks.device)
|
||||
|
||||
for e in range(E):
|
||||
for d0 in range(0, D, rows_per_chunk):
|
||||
d1 = min(d0 + rows_per_chunk, D)
|
||||
r0 = e * D * G + d0 * G
|
||||
r1 = e * D * G + d1 * G
|
||||
blk = blocks[r0:r1]
|
||||
exp = scales[r0:r1]
|
||||
dec = torch.empty((d1 - d0) * G, B * 2, dtype=dtype, device=blocks.device)
|
||||
dec[:, 0::2] = lut[(blk & 0x0F).to(torch.long)]
|
||||
dec[:, 1::2] = lut[(blk >> 4).to(torch.long)]
|
||||
torch.ldexp(dec, exp, out=dec)
|
||||
out[e, :, d0:d1] = dec.view(d1 - d0, val_per_row).transpose(0, 1)
|
||||
del blk, exp, dec
|
||||
return out
|
||||
|
||||
|
||||
def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
"""Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place."""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
for k in list(sd.keys()):
|
||||
if k.endswith("_blocks"):
|
||||
stem = k[: -len("_blocks")]
|
||||
sk = stem + "_scales"
|
||||
if sk in sd:
|
||||
pairs.append((stem, k, sk))
|
||||
|
||||
if not pairs:
|
||||
return sd
|
||||
|
||||
logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype)
|
||||
for stem, bk, sk in pairs:
|
||||
blocks = sd.pop(bk)
|
||||
scales = sd.pop(sk)
|
||||
sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype)
|
||||
del blocks, scales
|
||||
|
||||
gc.collect()
|
||||
return sd
|
||||
|
||||
|
||||
class LensGptOssClipModel(nn.Module):
|
||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **_):
|
||||
super().__init__()
|
||||
model_options = dict(model_options or {})
|
||||
|
||||
operations = model_options.get("custom_operations")
|
||||
quant_config = model_options.get("quantization_metadata")
|
||||
if operations is None:
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(
|
||||
quant_config, dtype, full_precision_mm=True
|
||||
)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
self.operations = operations
|
||||
|
||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||
self.config = GptOss20BConfig(**cfg_overrides)
|
||||
self.selected_layers = tuple(
|
||||
model_options.get("selected_layers", LENS_SELECTED_LAYERS)
|
||||
)
|
||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||
|
||||
# mxfp4_runtime=True keeps experts packed and dequants per hit at forward.
|
||||
self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False))
|
||||
|
||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
self.execution_device = None
|
||||
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
for p in self.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.execution_device = None
|
||||
|
||||
def _gather_tokens(self, token_weight_pairs):
|
||||
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
|
||||
pad_id = self._pad_token_id
|
||||
max_len = max(len(x) for x in ids_list)
|
||||
device = self.execution_device
|
||||
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
|
||||
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
|
||||
for i, x in enumerate(ids_list):
|
||||
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
|
||||
mask[i, : len(x)] = 1
|
||||
return ids, mask
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
# Empty negative: emit zero-length features + zero mask
|
||||
if all(len(batch) == 0 for batch in token_weight_pairs):
|
||||
device = self.execution_device
|
||||
B = len(token_weight_pairs)
|
||||
L = len(self.selected_layers)
|
||||
H = self.config.hidden_size
|
||||
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
|
||||
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
|
||||
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
|
||||
|
||||
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
|
||||
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
|
||||
layers = out["hidden_states"] # list of L × [B, S, H]
|
||||
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
|
||||
|
||||
offset = self.txt_offset
|
||||
if stacked.shape[1] > offset:
|
||||
stacked = stacked[:, offset:].contiguous()
|
||||
mask_trim = attn_mask[:, offset:]
|
||||
else:
|
||||
stacked = stacked[:, :0]
|
||||
mask_trim = attn_mask[:, :0]
|
||||
|
||||
B, S, L, H = stacked.shape
|
||||
flat = stacked.reshape(B, S, L * H)
|
||||
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
|
||||
return flat, None, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if any(k.startswith("model.") for k in sd):
|
||||
sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}
|
||||
sd.pop("lm_head.weight", None)
|
||||
|
||||
if self.mxfp4_runtime:
|
||||
device = next(self.transformer.parameters()).device
|
||||
for layer in self.transformer.layers:
|
||||
layer.mlp.experts.switch_to_mxfp4(device=device)
|
||||
else:
|
||||
sd = _dequant_mxfp4_state_dict(sd, self.dtype)
|
||||
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
|
||||
class LensTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||
|
||||
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False):
|
||||
class LensTEModel_(LensTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
mo = dict(model_options or {})
|
||||
if llama_quantization_metadata is not None:
|
||||
mo["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if mxfp4_runtime:
|
||||
mo["mxfp4_runtime"] = True
|
||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||
|
||||
return LensTEModel_
|
||||
@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
io.Boolean.Input(
|
||||
"pre_cfg",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip=(
|
||||
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
|
||||
"without clamping (can amplify). Matches the norm-scaled CFG used by "
|
||||
"models like Lens. Default false keeps the original post-CFG x0-space "
|
||||
"attenuate-only behavior."
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output(display_name="patched_model")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, strength) -> io.NodeOutput:
|
||||
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
if pre_cfg:
|
||||
def cfg_norm_pre(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
comb = uncond + cond_scale * (cond - uncond)
|
||||
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
|
||||
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
|
||||
rescale = torch.where(
|
||||
comb_norm > 0,
|
||||
cond_norm / comb_norm.clamp_min(1e-12),
|
||||
torch.ones_like(comb_norm),
|
||||
)
|
||||
rescaled = comb * rescale
|
||||
# strength blends back toward standard linear CFG (1.0 = full rescale).
|
||||
if strength != 1.0:
|
||||
rescaled = strength * rescaled + (1.0 - strength) * comb
|
||||
return rescaled
|
||||
m.set_model_sampler_cfg_function(cfg_norm_pre)
|
||||
else:
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
4
nodes.py
4
nodes.py
@ -961,7 +961,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -971,7 +971,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b (MXFP4 dequant on load)"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user