Initial Microsoft Lens support

This commit is contained in:
kijai 2026-05-23 17:43:25 +03:00
parent d80fcafee7
commit 5ecaf09544
8 changed files with 1514 additions and 11 deletions

578
comfy/ldm/lens/model.py Normal file
View File

@ -0,0 +1,578 @@
"""Lens denoising transformer (DiT)"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1)
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
return x_out.type_as(x)
class LensEmbedRope(nn.Module):
"""Frame/H/W axial RoPE shared between image and text streams."""
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None:
super().__init__()
self.theta = theta
self.axes_dim = list(axes_dim)
self.scale_rope = scale_rope
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
# Plain attributes (not buffers): register_buffer strips complex imag.
self.pos_freqs = torch.cat(
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
)
self.neg_freqs = torch.cat(
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
)
self.rope_cache: Dict[str, torch.Tensor] = {}
@staticmethod
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
assert dim % 2 == 0
freqs = torch.outer(
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
)
return torch.polar(torch.ones_like(freqs), freqs)
def forward(
self,
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
txt_seq_lens: Union[List[int], int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
if not isinstance(txt_seq_lens, list):
txt_seq_lens = [txt_seq_lens]
assert len(video_fhw) == 1, "video_fhw must have length 1"
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}_{device}"
video_freq = self.rope_cache.get(rope_key)
if video_freq is None:
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
self.rope_cache[rope_key] = video_freq
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
vid_freqs.append(video_freq)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
return torch.cat(vid_freqs, dim=0), txt_freqs
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
).view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat(
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
).view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class _TimestepEmbedder(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = F.silu(x)
return self.linear_2(x)
class LensTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
proj = _lens_time_proj(timestep, 256)
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
class GateMLP(nn.Module):
"""SwiGLU MLP."""
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
super().__init__()
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
class LensJointAttention(nn.Module):
"""Joint image+text attention with fused QKV per stream."""
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
dim_head: int = 64,
heads: int = 8,
out_dim: Optional[int] = None,
eps: float = 1e-5,
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.heads = self.inner_dim // dim_head
self.dim_head = dim_head
self.out_dim = out_dim if out_dim is not None else query_dim
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
# ModuleList([Linear, Identity]) for state-dict key compatibility.
self.to_out = nn.ModuleList([
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
nn.Identity(),
])
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, seq_img, _ = hidden_states.shape
seq_txt = encoder_hidden_states.shape[1]
# image stream
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
img_q, img_k, img_v = img_qkv.unbind(dim=2)
img_q = self.norm_q(img_q)
img_k = self.norm_k(img_k)
img_v = img_v.contiguous()
del img_qkv
# text stream
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
txt_q = self.norm_added_q(txt_q)
txt_k = self.norm_added_k(txt_k)
txt_v = txt_v.contiguous()
del txt_qkv
img_freqs, txt_freqs = image_rotary_emb
if img_freqs.shape[0] < seq_img:
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
img_freqs = img_freqs[:seq_img]
img_q = apply_rotary_emb_lens(img_q, img_freqs)
img_k = apply_rotary_emb_lens(img_k, img_freqs)
if seq_txt > 0:
if txt_freqs.shape[0] < seq_txt:
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
txt_freqs = txt_freqs[:seq_txt]
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
del img_k, txt_k
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
del img_v, txt_v
if attention_mask is not None:
expected = (bsz, 1, 1, seq_img + seq_txt)
if attention_mask.shape != expected:
raise ValueError(
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
)
attention_mask = attention_mask.to(q.dtype)
out = optimized_attention(
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
transformer_options=transformer_options,
)
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
txt_out = self.to_add_out(out[:, seq_img:, :])
return img_out, txt_out
class LensTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
rms_norm: bool = True,
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.attn = LensJointAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
eps=1e-5,
dtype=dtype,
device=device,
operations=operations,
)
if rms_norm:
NormCls = operations.RMSNorm
norm_kwargs = {}
else:
NormCls = operations.LayerNorm
norm_kwargs = {"elementwise_affine": False}
mlp_hidden = int(dim / 3 * 8)
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
self.img_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
self.txt_mod = nn.Sequential(
nn.SiLU(),
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
@staticmethod
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = mod_params.chunk(3, dim=-1)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
img_attn, txt_attn = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
return encoder_hidden_states, hidden_states
class _AdaLayerNormContinuousNoAffine(nn.Module):
"""AdaLayerNormContinuous(elementwise_affine=False).
The reference uses ``scale, shift = chunk(2)`` (scale first) opposite
to Flux's ``LastLayer``.
"""
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
dtype=None, device=None, operations=None) -> None:
super().__init__()
self.linear = operations.Linear(
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
)
self.eps = eps
self.embedding_dim = embedding_dim
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
emb = self.linear(F.silu(conditioning))
scale, shift = torch.chunk(emb, 2, dim=-1)
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class LensTransformer2DModel(nn.Module):
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 128,
out_channels: Optional[int] = 32,
num_layers: int = 48,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
enc_hidden_dim: int = 2880,
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
rms_norm: bool = True,
multi_layer_encoder_feature: bool = True,
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
image_model=None, # unused; accepted for detection-side configs.
dtype=None,
device=None,
operations=None,
) -> None:
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels if out_channels is not None else in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.multi_layer_encoder_feature = multi_layer_encoder_feature
self.selected_layer_index = list(selected_layer_index)
self.dtype = dtype
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
self.time_text_embed = LensTimestepProjEmbeddings(
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
)
if self.multi_layer_encoder_feature:
self.txt_norm = nn.ModuleList(
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
for _ in self.selected_layer_index]
)
self.txt_in = operations.Linear(
enc_hidden_dim * len(self.selected_layer_index),
self.inner_dim, bias=True, dtype=dtype, device=device,
)
else:
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList([
LensTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
eps=1e-6,
rms_norm=rms_norm,
dtype=dtype, device=device, operations=operations,
)
for _ in range(num_layers)
])
self.norm_out = _AdaLayerNormContinuousNoAffine(
self.inner_dim, self.inner_dim, eps=1e-6,
dtype=dtype, device=device, operations=operations,
)
self.proj_out = operations.Linear(
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
dtype=dtype, device=device,
)
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
if transformer_options is None:
transformer_options = {}
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward, self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
def _forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
control: Optional[Dict[str, Any]] = None,
**kwargs,
) -> torch.Tensor:
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
if transformer_options is None:
transformer_options = {}
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
B, C, h, w = x.shape
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
if self.multi_layer_encoder_feature:
L = len(self.selected_layer_index)
enc_dim = context.shape[-1] // L
encoder_hidden_states = list(
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
)
text_seq_len = encoder_hidden_states[0].shape[1]
else:
encoder_hidden_states = context
text_seq_len = context.shape[1]
if attention_mask is None:
attention_mask = torch.ones(
(B, text_seq_len), dtype=torch.bool, device=x.device
)
img_len = h * w
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if self.multi_layer_encoder_feature:
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
encoder_hidden_states = torch.cat(normed, dim=-1)
else:
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if "post_input" in patches:
for p in patches["post_input"]:
out = p({
"img": hidden_states,
"txt": encoder_hidden_states,
"transformer_options": transformer_options,
})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device)
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
temb=args["vec"],
image_rotary_emb=args["pe"],
attention_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
)
return out
out = blocks_replace[("double_block", i)](
{
"img": hidden_states,
"txt": encoder_hidden_states,
"vec": temb,
"pe": image_rotary_emb,
"attn_mask": joint_mask,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
encoder_hidden_states = out["txt"]
hidden_states = out["img"]
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=joint_mask,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({
"img": hidden_states,
"txt": encoder_hidden_states,
"x": x,
"block_index": i,
"transformer_options": transformer_options,
})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
if control is not None:
control_i = control.get("input")
if control_i is not None and i < len(control_i):
add = control_i[i]
if add is not None:
hidden_states[:, :add.shape[1]] += add
hidden_states = self.norm_out(hidden_states, temb)
out = self.proj_out(hidden_states)
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
@staticmethod
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
if text_mask.dtype != torch.bool:
text_mask = text_mask.bool()
bsz = text_mask.shape[0]
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
joint = torch.cat([img_ones, text_mask], dim=1)
additive = torch.zeros_like(joint, dtype=torch.float32)
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
return additive[:, None, None, :]

View File

@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lens.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
@ -1058,6 +1059,27 @@ class Flux2(Flux):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lens(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(
model_config, model_type, device=device,
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
)
def encode_adm(self, **kwargs):
return None # Lens has no pooled/ADM conditioning.
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)

View File

@ -755,6 +755,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
if multi_layer:
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
# Indices are TE-side; the DiT just consumes L layers in order.
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
else:
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
selected_layer_index = (0,)
return {
"image_model": "lens",
"in_channels": img_in_w.shape[1],
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
"enc_hidden_dim": enc_hidden_dim,
"multi_layer_encoder_feature": multi_layer,
"selected_layer_index": selected_layer_index,
}
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"

View File

@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.model_patcher
import comfy.lora
@ -1269,6 +1270,7 @@ class CLIPType(Enum):
FLUX2 = 25
LONGCAT_IMAGE = 26
COGVIDEOX = 27
LENS = 28
@ -1321,6 +1323,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
def detect_te_model(sd):
@ -1362,6 +1365,12 @@ def detect_te_model(sd):
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "model.layers.0.self_attn.sinks" in sd and (
"model.layers.0.mlp.experts.gate_up_proj" in sd
or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd
):
return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
@ -1544,6 +1553,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B:
mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data)
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4)
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")

View File

@ -829,6 +829,50 @@ class Flux2(Flux):
return None
class Lens(supported_models_base.BASE):
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
unet_config = {
"image_model": "lens",
}
sampling_settings = {
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
for hint in ("gpt_oss.transformer.", ""):
full_prefix = "{}{}".format(pref, hint)
if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict:
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict:
detect["mxfp4_runtime"] = True
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(**detect),
)
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(),
)
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@ -2096,6 +2140,7 @@ models = [
Omnigen2,
QwenImage,
Flux2,
Lens,
Kandinsky5Image,
Kandinsky5,
Anima,

View File

@ -0,0 +1,789 @@
"""GPT-OSS text encoder for Lens."""
from __future__ import annotations
import gc
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@dataclass
class GptOss20BConfig:
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
num_hidden_layers: int = 24
num_attention_heads: int = 64
num_key_value_heads: int = 8
head_dim: int = 64
num_local_experts: int = 32
num_experts_per_tok: int = 4
sliding_window: int = 128
original_max_position_embeddings: int = 4096
rope_theta: float = 150000.0
rope_factor: float = 32.0
rope_beta_fast: float = 32.0
rope_beta_slow: float = 1.0
rope_truncate: bool = False
rms_norm_eps: float = 1e-5
attention_bias: bool = True
layer_types: Optional[List[str]] = None
moe_alpha: float = 1.702
moe_limit: float = 7.0
def __post_init__(self):
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if (i + 1) % 2 else "full_attention"
for i in range(self.num_hidden_layers)
]
def _yarn_inv_freq(
head_dim: int,
base: float,
factor: float,
beta_fast: float,
beta_slow: float,
original_max_position_embeddings: int,
truncate: bool,
device=None,
) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim
def find_correction_dim(num_rotations: float) -> float:
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def find_correction_range() -> tuple[float, float]:
low = find_correction_dim(beta_fast)
high = find_correction_dim(beta_slow)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
if min_ == max_:
max_ += 0.001
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
return torch.clamp(linear, 0, 1)
def get_mscale(scale: float) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
attention_scaling = get_mscale(factor)
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range()
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
return inv_freq, attention_scaling
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
pos_e = position_ids[:, None, :].float()
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
sin_split = sin.shape[-1] // 2
return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sinks: torch.Tensor,
attention_mask: Optional[torch.Tensor],
num_heads: int,
num_kv_groups: int,
) -> torch.Tensor:
"""Attention with per-head sinks.
Sinks add a learned term to each row's softmax denominator but contribute
nothing to the output. We fake this by appending one zero k/v position and
putting the sink logit in the mask at that column.
"""
if num_kv_groups > 1 and not TORCH_HAS_GQA:
k = k.repeat_interleave(num_kv_groups, dim=1)
v = v.repeat_interleave(num_kv_groups, dim=1)
B, _, S_q, D = q.shape
H_kv = k.shape[1]
S_kv = k.shape[-2]
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
if attention_mask is not None:
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
else:
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
mask = torch.cat([mask_left, sinks_col], dim=-1)
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
class GptOssAttention(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
bias = config.attention_bias
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
B, S, _ = hidden_states.shape
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
q, k = apply_rope(q, k, freqs_cis)
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
return self.o_proj(out)
# Mixture of Experts
class GptOssTopKRouter(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
# Raw Parameters (not Linear) to match HF state-dict keys.
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
logits = F.linear(hidden_states, self.weight, self.bias)
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
# Softmax over top-k slice only (matches transformers), not all experts.
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
return scores, top_idx
class GptOssExperts(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.alpha = config.moe_alpha
self.limit = config.moe_limit
E = self.num_experts
H = self.hidden_size
I = self.intermediate_size
self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype))
self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype))
self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype))
self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype))
def switch_to_mxfp4(self, device=None):
"""Swap bf16 weight Parameters for uint8 MXFP4 packed buffers.
On-disk MXFP4 layout: ``[E, 2*I, G_up, 16]`` uint8 + ``[E, 2*I, G_up]``
uint8 (E8M0) for ``gate_up``; ``[E, H, G_down, 16]`` + ``[E, H, G_down]``
for ``down``. ``G_up * 32 = H``, ``G_down * 32 = I``.
"""
E, H, I = self.num_experts, self.hidden_size, self.intermediate_size
if H % 32 != 0 or I % 32 != 0:
raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={H}, I={I}")
del self.gate_up_proj
del self.down_proj
G_up = H // 32
G_down = I // 32
self.register_buffer("gate_up_proj_blocks", torch.empty(E, 2 * I, G_up, 16, dtype=torch.uint8, device=device))
self.register_buffer("gate_up_proj_scales", torch.empty(E, 2 * I, G_up, dtype=torch.uint8, device=device))
self.register_buffer("down_proj_blocks", torch.empty(E, H, G_down, 16, dtype=torch.uint8, device=device))
self.register_buffer("down_proj_scales", torch.empty(E, H, G_down, dtype=torch.uint8, device=device))
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate = gate_up[..., ::2]
up = gate_up[..., 1::2]
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
return (up + 1) * glu
@staticmethod
def _dequant_one_expert(
blocks_e: torch.Tensor,
scales_e: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Dequant one expert's MXFP4 ``[D, G, 16]`` + ``[D, G]`` to ``[G*32, D]``."""
D, G, B = blocks_e.shape
val_per_row = G * 32
lut = _fp4_lut(dtype, blocks_e.device)
blocks_flat = blocks_e.reshape(D * G, B)
scales_flat = (scales_e.to(torch.int32) - 127).reshape(D * G, 1)
dec = torch.empty(D * G, B * 2, dtype=dtype, device=blocks_e.device)
dec[:, 0::2] = lut[(blocks_flat & 0x0F).to(torch.long)]
dec[:, 1::2] = lut[(blocks_flat >> 4).to(torch.long)]
torch.ldexp(dec, scales_flat, out=dec)
return dec.view(D, val_per_row).transpose(0, 1)
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
N = hidden_states.shape[0]
top_k = router_indices.shape[-1]
H = hidden_states.shape[-1]
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
is_mxfp4 = hasattr(self, "gate_up_proj_blocks")
for ei in expert_hit:
expert_idx = int(ei.item())
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current = hidden_states[token_idx]
if is_mxfp4:
gate_up_w = self._dequant_one_expert(
self.gate_up_proj_blocks[expert_idx],
self.gate_up_proj_scales[expert_idx],
current.dtype,
)
down_w = self._dequant_one_expert(
self.down_proj_blocks[expert_idx],
self.down_proj_scales[expert_idx],
current.dtype,
)
else:
gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False)
down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False)
gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False)
down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False)
gate_up = current @ gate_up_w + gate_up_b
gated = self._apply_gate(gate_up)
expert_out = gated @ down_w + down_b
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
flat_idx = token_idx * top_k + top_k_pos
per_pair[flat_idx] = weighted.to(per_pair.dtype)
return per_pair.view(N, top_k, H).sum(dim=1)
class GptOssMLP(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
super().__init__()
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
self.experts = GptOssExperts(config, device=device, dtype=dtype)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
B, S, H = hidden_states.shape
flat = hidden_states.reshape(-1, H)
scores, idx = self.router(flat)
out = self.experts(flat, idx, scores)
return out.reshape(B, S, H)
# Decoder layer + model
class GptOssDecoderLayer(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
self.mlp = GptOssMLP(config, device=device, dtype=dtype)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.layer_type = config.layer_types[layer_idx]
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
i = torch.arange(S, device=device).view(-1, 1)
j = torch.arange(S, device=device).view(1, -1)
keep = (j <= i) & (j > i - window)
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
class GptOssModel(nn.Module):
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.config = config
self.dtype = dtype
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList(
[
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# Always build on CPU so the buffer survives meta-device construction.
inv_freq, attn_scaling = _yarn_inv_freq(
head_dim=config.head_dim,
base=config.rope_theta,
factor=config.rope_factor,
beta_fast=config.rope_beta_fast,
beta_slow=config.rope_beta_slow,
original_max_position_embeddings=config.original_max_position_embeddings,
truncate=config.rope_truncate,
device=torch.device("cpu"),
)
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
self.rope_attention_scaling = float(attn_scaling)
@property
def num_layers(self) -> int:
return self.config.num_hidden_layers
def get_input_embeddings(self):
return self.embed_tokens
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
) -> dict[str, torch.Tensor]:
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
masks = {"full_attention": full}
if any(t == "sliding_attention" for t in self.config.layer_types):
masks["sliding_attention"] = _make_sliding_causal_mask(
B, S, self.config.sliding_window, attention_mask, dtype, device
)
return masks
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
B, S = input_ids.shape
device = input_ids.device
dtype = self.dtype
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
capture_layers = list(capture_layers) if capture_layers else None
if capture_layers:
max_layer = max(capture_layers)
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
else:
max_layer = self.config.num_hidden_layers - 1
wanted = None
captured = None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
if wanted is not None and i in wanted:
captured[wanted[i]] = hidden_states
if i >= max_layer:
break
if captured is not None:
return {"hidden_states": captured}
return {"last_hidden_state": self.norm(hidden_states)}
# Lens chat-template constants (verbatim from the reference pipeline).
_LENS_CHAT_SYSTEM = (
"Describe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background."
)
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
LENS_TXT_OFFSET = 97
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
LENS_MAX_TOKENS = 512
# The reference GPT-OSS Harmony template injects today's date here
_LENS_CHAT_DATE = "2026-05-23"
def _lens_render_chat(prompt: str) -> str:
"""Render the Lens prompt in GPT-OSS Harmony format."""
return (
f"<|start|>system<|message|>"
f"You are ChatGPT, a large language model trained by OpenAI.\n"
f"Knowledge cutoff: 2024-06\n"
f"Current date: {_LENS_CHAT_DATE}\n\n"
f"Reasoning: medium\n\n"
f"# Valid channels: analysis, commentary, final. "
f"Channel must be included for every message.<|end|>"
f"<|start|>developer<|message|># Instructions\n\n"
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
f"<|start|>user<|message|>{prompt}<|end|>"
f"<|start|>assistant<|channel|>analysis<|message|>"
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>"
)
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
class _GptOssRawTokenizer:
"""Raw ``tokenizers.Tokenizer`` wrapper.
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
(``tokenizer_json`` key) rather than as a committed file. Extracted
it in ``sd.py`` and passes it here via ``tokenizer_data``.
"""
def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
if tokenizer_json_bytes is None:
raise ValueError(
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
"embeds the tokenizer."
)
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@classmethod
def from_pretrained(cls, tokenizer_data, **kwargs):
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
def __call__(self, text):
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
def get_vocab(self):
return self.tokenizer.get_vocab()
def convert_tokens_to_ids(self, tokens):
return [self.tokenizer.token_to_id(t) for t in tokens]
def decode(self, ids, **kwargs):
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
tokenizer_json_data = None
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
self.tokenizer_json_data = tokenizer_json
super().__init__(
tokenizer_json,
embedding_directory=embedding_directory,
pad_with_end=False,
embedding_size=2880,
embedding_key="gpt_oss",
tokenizer_class=_GptOssRawTokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_left=False,
disable_weights=True,
tokenizer_data=tokenizer_data,
)
self.pad_token_id = _LENS_PAD_TOKEN_ID
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
if not text or not text.strip():
return [[]]
rendered = _lens_render_chat(text)
ids = self.tokenizer(rendered)["input_ids"]
if len(ids) > LENS_MAX_TOKENS:
ids = ids[:LENS_MAX_TOKENS]
return [[(int(t), 1.0) for t in ids]]
def state_dict(self):
if self.tokenizer_json_data is not None:
return {"tokenizer_json": self.tokenizer_json_data}
return {}
class LensTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="gpt_oss",
tokenizer=LensGptOssTokenizer,
)
# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa).
_FP4_VALUES: Tuple[float, ...] = (
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
)
_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {}
def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Cached per (dtype, device) FP4 lookup table — avoids per-call allocation."""
key = (dtype, str(device))
lut = _FP4_LUT_CACHE.get(key)
if lut is None:
lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device)
_FP4_LUT_CACHE[key] = lut
return lut
def _safe_dequant_moe_tensor(
blocks: torch.Tensor,
scales: torch.Tensor,
*,
dtype: torch.dtype = torch.bfloat16,
rows_per_chunk: int = 4096,
) -> torch.Tensor:
"""Eager full-tensor MXFP4 dequant -> ``[E, H, 2*I]`` ``dtype``.
Allocates the output in its final transposed layout and writes in chunks
"""
blocks = blocks.to(torch.uint8)
scales = scales.to(torch.int32) - 127
assert blocks.shape[:-1] == scales.shape, (
f"{blocks.shape[:-1]=} does not match {scales.shape=}"
)
*prefix_shape, G, B = blocks.shape
if len(prefix_shape) != 2:
raise ValueError(f"expected 2-D prefix (E, 2*I); got {prefix_shape}")
E, D = prefix_shape
val_per_row = G * B * 2 # this is H after dequant
rows_total = E * D * G
blocks = blocks.reshape(rows_total, B)
scales = scales.reshape(rows_total, 1)
lut = _fp4_lut(dtype, blocks.device)
out = torch.empty(E, val_per_row, D, dtype=dtype, device=blocks.device)
for e in range(E):
for d0 in range(0, D, rows_per_chunk):
d1 = min(d0 + rows_per_chunk, D)
r0 = e * D * G + d0 * G
r1 = e * D * G + d1 * G
blk = blocks[r0:r1]
exp = scales[r0:r1]
dec = torch.empty((d1 - d0) * G, B * 2, dtype=dtype, device=blocks.device)
dec[:, 0::2] = lut[(blk & 0x0F).to(torch.long)]
dec[:, 1::2] = lut[(blk >> 4).to(torch.long)]
torch.ldexp(dec, exp, out=dec)
out[e, :, d0:d1] = dec.view(d1 - d0, val_per_row).transpose(0, 1)
del blk, exp, dec
return out
def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
"""Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place."""
pairs: List[Tuple[str, str, str]] = []
for k in list(sd.keys()):
if k.endswith("_blocks"):
stem = k[: -len("_blocks")]
sk = stem + "_scales"
if sk in sd:
pairs.append((stem, k, sk))
if not pairs:
return sd
logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype)
for stem, bk, sk in pairs:
blocks = sd.pop(bk)
scales = sd.pop(sk)
sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype)
del blocks, scales
gc.collect()
return sd
class LensGptOssClipModel(nn.Module):
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
def __init__(self, device="cpu", dtype=None, model_options=None, **_):
super().__init__()
model_options = dict(model_options or {})
operations = model_options.get("custom_operations")
quant_config = model_options.get("quantization_metadata")
if operations is None:
if quant_config is not None:
operations = comfy.ops.mixed_precision_ops(
quant_config, dtype, full_precision_mm=True
)
else:
operations = comfy.ops.manual_cast
self.operations = operations
cfg_overrides = model_options.get("gpt_oss_config", {})
self.config = GptOss20BConfig(**cfg_overrides)
self.selected_layers = tuple(
model_options.get("selected_layers", LENS_SELECTED_LAYERS)
)
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
# mxfp4_runtime=True keeps experts packed and dequants per hit at forward.
self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False))
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
self.dtype = dtype
self.execution_device = None
self._pad_token_id = _LENS_PAD_TOKEN_ID
for p in self.parameters():
p.requires_grad = False
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
def reset_clip_options(self):
self.execution_device = None
def _gather_tokens(self, token_weight_pairs):
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
pad_id = self._pad_token_id
max_len = max(len(x) for x in ids_list)
device = self.execution_device
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
for i, x in enumerate(ids_list):
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
mask[i, : len(x)] = 1
return ids, mask
def encode_token_weights(self, token_weight_pairs):
# Empty negative: emit zero-length features + zero mask
if all(len(batch) == 0 for batch in token_weight_pairs):
device = self.execution_device
B = len(token_weight_pairs)
L = len(self.selected_layers)
H = self.config.hidden_size
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
layers = out["hidden_states"] # list of L × [B, S, H]
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
offset = self.txt_offset
if stacked.shape[1] > offset:
stacked = stacked[:, offset:].contiguous()
mask_trim = attn_mask[:, offset:]
else:
stacked = stacked[:, :0]
mask_trim = attn_mask[:, :0]
B, S, L, H = stacked.shape
flat = stacked.reshape(B, S, L * H)
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
return flat, None, extra
def load_sd(self, sd):
if any(k.startswith("model.") for k in sd):
sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}
sd.pop("lm_head.weight", None)
if self.mxfp4_runtime:
device = next(self.transformer.parameters()).device
for layer in self.transformer.layers:
layer.mlp.experts.switch_to_mxfp4(device=device)
else:
sd = _dequant_mxfp4_state_dict(sd, self.dtype)
return self.transformer.load_state_dict(sd, strict=False, assign=True)
class LensTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False):
class LensTEModel_(LensTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
mo = dict(model_options or {})
if llama_quantization_metadata is not None:
mo["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
if mxfp4_runtime:
mo["mxfp4_runtime"] = True
super().__init__(device=device, dtype=dtype, model_options=mo)
return LensTEModel_

View File

@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
"pre_cfg",
default=False,
optional=True,
tooltip=(
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
"without clamping (can amplify). Matches the norm-scaled CFG used by "
"models like Lens. Default false keeps the original post-CFG x0-space "
"attenuate-only behavior."
),
),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
if pre_cfg:
def cfg_norm_pre(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
comb = uncond + cond_scale * (cond - uncond)
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
rescale = torch.where(
comb_norm > 0,
cond_norm / comb_norm.clamp_min(1e-12),
torch.ones_like(comb_norm),
)
rescaled = comb * rescale
# strength blends back toward standard linear CFG (1.0 = full rescale).
if strength != 1.0:
rescaled = strength * rescaled + (1.0 - strength) * comb
return rescaled
m.set_model_sampler_cfg_function(cfg_norm_pre)
else:
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)

View File

@ -961,7 +961,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -971,7 +971,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b (MXFP4 dequant on load)"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)