Merge remote-tracking branch 'upstream/master' into pixeldit

This commit is contained in:
kijai 2026-05-26 09:04:43 +03:00
commit 46e0b4b232
9 changed files with 1533 additions and 208 deletions

513
comfy/ldm/lens/model.py Normal file
View File

@ -0,0 +1,513 @@
"""Lens denoising transformer (DiT)"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Embed timesteps ``t`` into ``dim`` features.

    Delegates directly to ``comfy.ldm.flux.layers.timestep_embedding``; the
    result of that call is returned unchanged.
    """
    embedding = comfy.ldm.flux.layers.timestep_embedding(t, dim)
    return embedding
def _lens_position_ids(
    frame: int, height: int, width: int, text_seq_len: int,
    scale_rope: bool = True, device=None,
) -> torch.Tensor:
    """Build axial (frame, h, w) position ids for the joint image + text sequence.

    Image tokens come first, enumerated frame-major, then row, then column.
    Text tokens follow and carry the same scalar position on all three axes.

    With ``scale_rope=True`` the row/column coordinates are shifted so that
    they straddle zero (``-(n - n // 2) .. n // 2 - 1``) and the first text
    position is ``max(height // 2, width // 2)``. Otherwise coordinates start
    at 0 and text starts at ``max(height, width)``.

    Returns:
        Tensor of shape ``[frame * height * width + text_seq_len, 3]``. The
        caller is expected to add a batch dimension before ``EmbedND``.
    """
    def _axis_coords(extent: int) -> torch.Tensor:
        # 0..extent-1, optionally re-centred so negatives precede positives.
        coords = torch.arange(extent, device=device)
        if scale_rope:
            coords = coords - (extent - extent // 2)
        return coords

    rows = _axis_coords(height)
    cols = _axis_coords(width)
    frames = torch.arange(frame, device=device)

    if scale_rope:
        first_text_pos = max(height // 2, width // 2)
    else:
        first_text_pos = max(height, width)

    # One (frame, h, w) triple per image token, filled axis by axis via broadcasting.
    grid = torch.zeros(frame, height, width, 3, device=device)
    grid[..., 0] = frames.view(-1, 1, 1)
    grid[..., 1] = rows.view(1, -1, 1)
    grid[..., 2] = cols.view(1, 1, -1)
    image_ids = grid.reshape(-1, 3)

    # Text positions replicate across all 3 axes (matches original packing).
    text_pos = torch.arange(first_text_pos, first_text_pos + text_seq_len, device=device).float()
    text_ids = text_pos.unsqueeze(1).expand(text_seq_len, 3)

    return torch.cat([image_ids, text_ids], dim=0)
class _TimestepEmbedder(nn.Module):
    """Two-layer MLP: ``Linear -> SiLU -> Linear``.

    Both linear layers are built from the injected ``operations`` namespace so
    the caller controls the concrete Linear implementation.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, **factory)
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, **factory)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dim ``in_channels``) to ``time_embed_dim`` features."""
        hidden = F.silu(self.linear_1(x))
        return self.linear_2(hidden)
class LensTimestepProjEmbeddings(nn.Module):
    """Timestep conditioning: 256-wide time projection followed by an MLP."""

    def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        # The MLP input width (256) matches the projection width used in forward().
        self.timestep_embedder = _TimestepEmbedder(
            256, embedding_dim, dtype=dtype, device=device, operations=operations,
        )

    def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the embedded ``timestep``.

        ``hidden_states`` is only consulted for its dtype: the projected
        timestep features are cast to it before entering the MLP.
        """
        features = _lens_time_proj(timestep, 256).to(dtype=hidden_states.dtype)
        return self.timestep_embedder(features)
class GateMLP(nn.Module):
    """SwiGLU MLP: ``w2(silu(w1(x)) * w3(x))`` with bias-free linears."""

    def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        linear_kwargs = {"bias": False, "dtype": dtype, "device": device}
        self.w1 = operations.Linear(dim, hidden_dim, **linear_kwargs)
        self.w2 = operations.Linear(hidden_dim, dim, **linear_kwargs)
        self.w3 = operations.Linear(dim, hidden_dim, **linear_kwargs)

    def forward(self, x):
        # Both the activation and the gating multiply run in place on the
        # w1 output, so no extra hidden-sized tensor is allocated for them.
        gated = F.silu(self.w1(x), inplace=True)
        gated.mul_(self.w3(x))
        return self.w2(gated)
class LensJointAttention(nn.Module):
    """Joint image+text attention with one fused QKV projection per stream.

    Image and text tokens are projected separately, RMS-normalised per head
    (queries and keys only), concatenated along the sequence axis (image
    first) and attended jointly. The attention output is split back into the
    two streams, each with its own output projection.
    """

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        dim_head: int = 64,
        heads: int = 8,
        out_dim: Optional[int] = None,
        eps: float = 1e-5,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        # When ``out_dim`` is given it fixes the attention width (and the head
        # count is derived from it); otherwise the width is ``dim_head * heads``
        # and the image stream projects back to ``query_dim``.
        if out_dim is None:
            self.inner_dim = dim_head * heads
            self.out_dim = query_dim
        else:
            self.inner_dim = out_dim
            self.out_dim = out_dim
        self.heads = self.inner_dim // dim_head
        self.dim_head = dim_head

        self.norm_q = operations.RMSNorm(dim_head, eps=eps, **factory)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, **factory)
        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, **factory)
        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, **factory)

        self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, **factory)
        self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, **factory)

        # ModuleList([Linear, Identity]) for state-dict key compatibility.
        self.to_out = nn.ModuleList([
            operations.Linear(self.inner_dim, self.out_dim, bias=True, **factory),
            nn.Identity(),
        ])
        self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, **factory)

    def _project_qkv(self, tokens, bsz, fused_proj, q_norm, k_norm):
        """Run one stream's fused projection and return ``(q, k, v)`` as ``[B, S, H, D]``."""
        seq_len = tokens.shape[1]
        fused = fused_proj(tokens).view(bsz, seq_len, 3, self.heads, self.dim_head)
        q, k, v = fused.unbind(dim=2)
        # ``v`` is a strided view into the fused tensor; copy it so the fused
        # buffer can be released once this helper returns.
        return q_norm(q), k_norm(k), v.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend jointly over image and text tokens.

        Args:
            hidden_states: image tokens, ``[B, S_img, query_dim]``.
            encoder_hidden_states: text tokens, ``[B, S_txt, added_kv_proj_dim]``.
            freqs_cis: rotary embedding passed straight to ``apply_rope``.
            attention_mask: optional mask of shape ``[B, 1, 1, S_img + S_txt]``;
                it is cast to the query dtype before use.
            transformer_options: forwarded to ``optimized_attention``.

        Returns:
            ``(img_out, txt_out)`` – the projected image and text outputs.

        Raises:
            ValueError: if ``attention_mask`` does not have the expected shape.
        """
        bsz, seq_img, _ = hidden_states.shape
        seq_txt = encoder_hidden_states.shape[1]

        # image stream
        img_q, img_k, img_v = self._project_qkv(
            hidden_states, bsz, self.img_qkv, self.norm_q, self.norm_k,
        )
        # text stream
        txt_q, txt_k, txt_v = self._project_qkv(
            encoder_hidden_states, bsz, self.txt_qkv, self.norm_added_q, self.norm_added_k,
        )

        # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
        q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
        del img_q, txt_q
        k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
        del img_k, txt_k
        v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
        del img_v, txt_v

        q, k = apply_rope(q, k, freqs_cis)

        if attention_mask is not None:
            expected = (bsz, 1, 1, seq_img + seq_txt)
            if attention_mask.shape != expected:
                raise ValueError(
                    f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
                )
            attention_mask = attention_mask.to(q.dtype)

        out = optimized_attention(
            q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
            transformer_options=transformer_options,
        )

        # Image tokens occupy the first ``seq_img`` positions of the joint sequence.
        img_out = out[:, :seq_img, :]
        for layer in self.to_out:
            img_out = layer(img_out)
        txt_out = self.to_add_out(out[:, seq_img:, :])
        return img_out, txt_out
class LensTransformerBlock(nn.Module):
    """Dual-stream transformer block with timestep-driven modulation.

    Each stream (image, text) has its own modulation projection, two norms and
    a SwiGLU MLP; the attention layer is shared and operates on both streams
    jointly.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        rms_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        # The attention q/k norms always use eps=1e-5, independent of ``eps``.
        self.attn = LensJointAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            eps=1e-5,
            operations=operations,
            **factory,
        )

        if rms_norm:
            norm_cls, norm_extra = operations.RMSNorm, {}
        else:
            norm_cls, norm_extra = operations.LayerNorm, {"elementwise_affine": False}

        def make_norm():
            return norm_cls(dim, eps=eps, **factory, **norm_extra)

        def make_modulation():
            # Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
            return nn.Sequential(
                nn.SiLU(),
                operations.Linear(dim, 6 * dim, bias=True, **factory),
            )

        mlp_hidden = int(dim / 3 * 8)

        self.img_mod = make_modulation()
        self.img_norm1 = make_norm()
        self.img_norm2 = make_norm()
        self.img_mlp = GateMLP(dim, mlp_hidden, operations=operations, **factory)

        self.txt_mod = make_modulation()
        self.txt_norm1 = make_norm()
        self.txt_norm2 = make_norm()
        self.txt_mlp = GateMLP(dim, mlp_hidden, operations=operations, **factory)

    @staticmethod
    def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply shift/scale modulation to ``x`` and return it with the gate.

        ``mod_params`` is split into three equal chunks along the last axis, in
        the order (shift, scale, gate). Each chunk gets a new axis at dim 1 so
        it broadcasts over the sequence dimension of ``x``.
        """
        shift, scale, gate = (part.unsqueeze(1) for part in mod_params.chunk(3, dim=-1))
        return x * (1 + scale) + shift, gate

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then MLP on both streams, each as a gated residual.

        Returns:
            ``(encoder_hidden_states, hidden_states)`` – note the text stream
            comes first in the returned tuple.
        """
        # Each modulation projection yields 6*dim values: the first half drives
        # the attention sub-layer, the second half the MLP sub-layer.
        img_attn_mod, img_mlp_mod = self.img_mod(temb).chunk(2, dim=-1)
        txt_attn_mod, txt_mlp_mod = self.txt_mod(temb).chunk(2, dim=-1)

        img_in, img_attn_gate = self._modulate(self.img_norm1(hidden_states), img_attn_mod)
        txt_in, txt_attn_gate = self._modulate(self.txt_norm1(encoder_hidden_states), txt_attn_mod)

        img_attn, txt_attn = self.attn(
            hidden_states=img_in,
            encoder_hidden_states=txt_in,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + img_attn_gate * img_attn
        encoder_hidden_states = encoder_hidden_states + txt_attn_gate * txt_attn

        img_mlp_in, img_mlp_gate = self._modulate(self.img_norm2(hidden_states), img_mlp_mod)
        hidden_states = hidden_states + img_mlp_gate * self.img_mlp(img_mlp_in)

        txt_mlp_in, txt_mlp_gate = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mlp_mod)
        encoder_hidden_states = encoder_hidden_states + txt_mlp_gate * self.txt_mlp(txt_mlp_in)

        return encoder_hidden_states, hidden_states
class _AdaLayerNormContinuousNoAffine(nn.Module):
    """Conditioned LayerNorm without learned affine parameters.

    The conditioning vector is passed through SiLU and a linear layer, and the
    result is split as ``scale, shift = chunk(2)`` (scale first). This ordering
    follows the reference and is the opposite of Flux's ``LastLayer``.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
                 dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.linear = operations.Linear(
            conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
        )
        self.eps = eps
        self.embedding_dim = embedding_dim

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last axis, then scale and shift it.

        The scale and shift get a new axis at dim 1 so they broadcast over the
        sequence dimension of ``x``.
        """
        scale, shift = self.linear(F.silu(conditioning)).chunk(2, dim=-1)
        normed = F.layer_norm(x, (self.embedding_dim,), weight=None, bias=None, eps=self.eps)
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class LensTransformer2DModel(nn.Module):
    """Lens dual-stream MMDiT.

    With the default arguments the model has 48 blocks and an inner width of
    ``24 * 64 = 1536``, and it consumes several text-encoder layers that arrive
    concatenated along the feature axis of ``context``.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 128,
        out_channels: Optional[int] = 32,
        num_layers: int = 48,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        enc_hidden_dim: int = 2880,
        axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
        rms_norm: bool = True,
        multi_layer_encoder_feature: bool = True,
        selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
        image_model=None,  # unused; accepted for detection-side configs.
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.multi_layer_encoder_feature = multi_layer_encoder_feature
        self.selected_layer_index = list(selected_layer_index)
        self.dtype = dtype

        self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
        self.time_text_embed = LensTimestepProjEmbeddings(
            embedding_dim=self.inner_dim, operations=operations, **factory,
        )

        if self.multi_layer_encoder_feature:
            # One norm per selected text-encoder layer; the normed features are
            # concatenated before the shared input projection.
            num_text_layers = len(self.selected_layer_index)
            self.txt_norm = nn.ModuleList([
                operations.RMSNorm(enc_hidden_dim, eps=1e-5, **factory)
                for _ in range(num_text_layers)
            ])
            self.txt_in = operations.Linear(
                enc_hidden_dim * num_text_layers, self.inner_dim, bias=True, **factory,
            )
        else:
            self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, **factory)
            self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, **factory)

        self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, **factory)

        blocks = []
        for _ in range(num_layers):
            blocks.append(LensTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                eps=1e-6,
                rms_norm=rms_norm,
                operations=operations,
                **factory,
            ))
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm_out = _AdaLayerNormContinuousNoAffine(
            self.inner_dim, self.inner_dim, eps=1e-6, operations=operations, **factory,
        )
        self.proj_out = operations.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, **factory,
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
        """Entry point: run ``_forward`` through the registered diffusion-model wrappers."""
        if transformer_options is None:
            transformer_options = {}
        wrappers = comfy.patcher_extension.get_all_wrappers(
            comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options,
        )
        executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self, wrappers,
        )
        return executor.execute(x, timestep, context, attention_mask, transformer_options, **kwargs)

    def _forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
        control: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.

        Every spatial position of ``x`` becomes one image token. A missing
        ``attention_mask`` is treated as "all text tokens valid". Patches found
        in ``transformer_options`` (``post_input``, ``double_block`` and block
        replacements under ``patches_replace["dit"]``) and ``control["input"]``
        residuals are applied where present.

        Returns:
            Tensor with the same ``[B, C, h, w]`` layout as ``x``.
        """
        # Work on a shallow copy: per-call keys are written into it below.
        transformer_options = {} if transformer_options is None else transformer_options.copy()
        patches = transformer_options.get("patches", {})
        blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {})

        B, C, h, w = x.shape
        img_len = h * w
        hidden_states = x.permute(0, 2, 3, 1).reshape(B, img_len, C)

        if self.multi_layer_encoder_feature:
            # ``context`` packs the selected layers along its last axis; split
            # them back into one [B, S, enc_dim] tensor per layer.
            num_text_layers = len(self.selected_layer_index)
            enc_dim = context.shape[-1] // num_text_layers
            layer_features = context.reshape(B, -1, num_text_layers, enc_dim).unbind(dim=2)
            text_seq_len = layer_features[0].shape[1]
        else:
            layer_features = None
            text_seq_len = context.shape[1]

        if attention_mask is None:
            attention_mask = torch.ones((B, text_seq_len), dtype=torch.bool, device=x.device)
        joint_mask = self._build_joint_attention_mask(attention_mask, img_len)

        hidden_states = self.img_in(hidden_states)
        timestep = timestep.to(hidden_states.dtype)

        if self.multi_layer_encoder_feature:
            normed_layers = [norm(feat) for norm, feat in zip(self.txt_norm, layer_features)]
            encoder_hidden_states = torch.cat(normed_layers, dim=-1)
        else:
            encoder_hidden_states = self.txt_norm(context)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        for patch in patches.get("post_input", ()):
            patched = patch({
                "img": hidden_states,
                "txt": encoder_hidden_states,
                "transformer_options": transformer_options,
            })
            hidden_states = patched["img"]
            encoder_hidden_states = patched["txt"]

        temb = self.time_text_embed(timestep, hidden_states)
        ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
        freqs_cis = self.pos_embed(ids)

        transformer_options["total_blocks"] = len(self.transformer_blocks)
        transformer_options["block_type"] = "double"
        control_inputs = control.get("input") if control is not None else None

        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            replacement = blocks_replace.get(("double_block", i))

            if replacement is None:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    freqs_cis=freqs_cis,
                    attention_mask=joint_mask,
                    transformer_options=transformer_options,
                )
            else:
                def block_wrap(args):
                    # Adapter handed to the replacement as "original_block":
                    # maps the dict-based patch protocol onto the block call.
                    result = {}
                    result["txt"], result["img"] = block(
                        hidden_states=args["img"],
                        encoder_hidden_states=args["txt"],
                        temb=args["vec"],
                        freqs_cis=args["pe"],
                        attention_mask=args.get("attn_mask"),
                        transformer_options=args.get("transformer_options"),
                    )
                    return result

                replaced = replacement(
                    {
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "vec": temb,
                        "pe": freqs_cis,
                        "attn_mask": joint_mask,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                encoder_hidden_states = replaced["txt"]
                hidden_states = replaced["img"]

            for patch in patches.get("double_block", ()):
                patched = patch({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "x": x,
                    "block_index": i,
                    "transformer_options": transformer_options,
                })
                hidden_states = patched["img"]
                encoder_hidden_states = patched["txt"]

            if control_inputs is not None and i < len(control_inputs):
                residual = control_inputs[i]
                if residual is not None:
                    # In-place add onto the leading tokens covered by the residual.
                    hidden_states[:, :residual.shape[1]] += residual

        hidden_states = self.norm_out(hidden_states, temb)
        projected = self.proj_out(hidden_states)
        # The reshape reuses the input channel count C, so it relies on
        # patch_size**2 * out_channels == in_channels (true for the defaults).
        return projected.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
        """Turn a ``[B, S_txt]`` validity mask into an additive joint mask.

        Image tokens are always attendable. The result is float32 with shape
        ``[B, 1, 1, img_len + S_txt]``: 0 where a token may be attended and the
        most negative float32 value where it is masked out.
        """
        valid_text = text_mask if text_mask.dtype == torch.bool else text_mask.bool()
        valid_img = torch.ones(
            (valid_text.shape[0], img_len), dtype=torch.bool, device=valid_text.device,
        )
        keep = torch.cat([valid_img, valid_text], dim=1)

        bias = torch.zeros_like(keep, dtype=torch.float32)
        bias.masked_fill_(~keep, torch.finfo(torch.float32).min)
        return bias[:, None, None, :]

View File

@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lens.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
@ -1060,6 +1061,27 @@ class Flux2(Flux):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lens(BaseModel):
    """Model wrapper that plugs ``LensTransformer2DModel`` into ``BaseModel``."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
        )

    def encode_adm(self, **kwargs):
        """Lens has no pooled/ADM conditioning, so nothing is encoded."""
        return None

    def extra_conds(self, **kwargs):
        """Extend the base conditioning with the text features and their mask.

        ``cross_attn`` is stored under ``c_crossattn`` and ``attention_mask``
        under ``attention_mask``; either is skipped when absent or ``None``.
        """
        out = super().extra_conds(**kwargs)
        # (output key, kwarg name) pairs, each wrapped as a CONDRegular.
        cond_sources = (
            ('c_crossattn', "cross_attn"),
            ('attention_mask', "attention_mask"),
        )
        for cond_key, kwarg_name in cond_sources:
            value = kwargs.get(kwarg_name, None)
            if value is not None:
                out[cond_key] = comfy.conds.CONDRegular(value)
        return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)

View File

@ -772,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
if multi_layer:
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
# Indices are TE-side; the DiT just consumes L layers in order.
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
else:
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
selected_layer_index = (0,)
return {
"image_model": "lens",
"in_channels": img_in_w.shape[1],
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
"enc_hidden_dim": enc_hidden_dim,
"multi_layer_encoder_feature": multi_layer,
"selected_layer_index": selected_layer_index,
}
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"

View File

@ -18,6 +18,7 @@
import torch
import logging
import contextlib
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
# Quantized-weight module helpers
def _quantized_apply(module, fn, recurse=True):
    """Apply ``fn`` to a module's own parameters and buffers, re-wrapping parameters.

    Every non-``None`` parameter is replaced by a fresh ``Parameter`` (with
    ``requires_grad=False``) around ``fn(param)``, so .to()/.cuda() propagate
    through QuantizedTensor weights. Buffers are replaced by ``fn(buf)``.
    With ``recurse=True`` child modules are handled first through their own
    ``_apply``.

    Returns:
        ``module`` itself.
    """
    if recurse:
        for submodule in module.children():
            submodule._apply(fn)

    # Snapshot the entries: they are re-registered under the same names below.
    for name, param in list(module._parameters.items()):
        if param is None:
            continue
        converted = fn(param)
        # An inference tensor produced outside inference mode is cloned
        # before being wrapped in a Parameter.
        if (not torch.is_inference_mode_enabled()) and converted.is_inference():
            converted = converted.clone()
        module.register_parameter(name, torch.nn.Parameter(converted, requires_grad=False))

    for name, buf in list(module._buffers.items()):
        if buf is None:
            continue
        module._buffers[name] = fn(buf)

    return module
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
    """Shared ``_load_from_state_dict`` body for quantized-weight modules.

    Pops ``{prefix}weight`` (plus its scales and, with ``load_extra_params``,
    any further per-format parameters) out of ``state_dict`` and installs
    ``module.weight`` as either a plain compute-dtype Parameter or a Parameter
    wrapping a ``QuantizedTensor``. Then delegates the remaining keys to
    ``super_load`` and removes the keys consumed here from ``missing_keys``.

    The target device and compute dtype are read from ``module.factory_kwargs``
    and the disabled formats from ``module._disabled_formats``.

    If the weight key is absent a warning is logged, ``module.weight`` is set
    to ``None`` and the function returns without calling ``super_load``.

    Raises:
        ValueError: when the quant config names no format, required scales are
            missing, or the format has no scale-loading branch here.
    """
    device = module.factory_kwargs["device"]
    compute_dtype = module.factory_kwargs["dtype"]
    disabled_formats = module._disabled_formats
    layer_name = prefix.rstrip('.')

    weight_key = f"{prefix}weight"
    weight = state_dict.pop(weight_key, None)
    if weight is None:
        logging.warning(f"Missing weight for layer {layer_name}")
        module.weight = None
        return

    # Keys handled here; they are stripped from missing_keys at the end.
    manually_loaded_keys = [weight_key]

    def pop_scale(name, dtype=None):
        """Pop ``{prefix}{name}``, move it to ``device`` and optionally reinterpret its dtype."""
        key = f"{prefix}{name}"
        tensor = state_dict.pop(key, None)
        if tensor is None:
            return None
        tensor = tensor.to(device=device)
        if dtype is not None:
            tensor = tensor.view(dtype=dtype)
        manually_loaded_keys.append(key)
        return tensor

    raw_conf = state_dict.pop(f"{prefix}comfy_quant", None)
    if raw_conf is None:
        # No quant config: load as an ordinary compute-dtype weight.
        module.weight = torch.nn.Parameter(
            weight.to(device=device, dtype=compute_dtype), requires_grad=False,
        )
    else:
        # The config is JSON text stored as a byte tensor.
        layer_conf = json.loads(raw_conf.numpy().tobytes())

        module.quant_format = layer_conf.get("format", None)
        module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
        if not module._full_precision_mm:
            module._full_precision_mm = module._full_precision_mm_config
        if module.quant_format in disabled_formats:
            module._full_precision_mm = True
        if module.quant_format is None:
            raise ValueError(f"Unknown quantization format for layer {layer_name}")

        qconfig = QUANT_ALGOS[module.quant_format]
        module.layout_type = qconfig["comfy_tensor_layout"]
        layout_cls = get_layout_class(module.layout_type)

        # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
        fmt = module.quant_format
        if fmt in ("float8_e4m3fn", "float8_e5m2"):
            scales = {"scale": pop_scale("weight_scale")}
        elif fmt == "mxfp8":
            block_scale = pop_scale("weight_scale", torch.float8_e8m0fnu)
            if block_scale is None:
                raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
            scales = {"scale": block_scale}
        elif fmt == "nvfp4":
            tensor_scale = pop_scale("weight_scale_2")
            block_scale = pop_scale("weight_scale", torch.float8_e4m3fn)
            if tensor_scale is None or block_scale is None:
                raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
            scales = {"scale": tensor_scale, "block_scale": block_scale}
        else:
            raise ValueError(f"Unsupported quantization format: {module.quant_format}")

        params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
        quantized = QuantizedTensor(
            weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params,
        )
        module.weight = torch.nn.Parameter(quantized, requires_grad=False)

        if load_extra_params:
            already_loaded = {"weight_scale", "weight_scale_2"}
            for param_name in qconfig["parameters"]:
                if param_name in already_loaded:
                    continue
                param_key = f"{prefix}{param_name}"
                value = state_dict.pop(param_key, None)
                if value is None:
                    continue
                module.register_parameter(
                    param_name, torch.nn.Parameter(value.to(device=device), requires_grad=False),
                )
                manually_loaded_keys.append(param_key)

    super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    # The keys popped above are not visible to super_load, which may have
    # reported them missing; drop those reports in place.
    for key in manually_loaded_keys:
        if key in missing_keys:
            missing_keys.remove(key)
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
    """Shared ``state_dict`` body for quantized-weight modules.

    Writes the module's bias and weight into ``sd`` under ``prefix``. For a
    ``QuantizedTensor`` weight, the tensor's own state-dict entries are merged
    in together with a ``comfy_quant`` key holding the JSON quant config encoded
    as a uint8 tensor.

    Args:
        module: the module being serialised.
        sd: destination dict; mutated in place and returned.
        prefix: key prefix for every entry.
        extra_quant_conf: optional mapping merged into the comfy_quant JSON.
        extra_quant_params: attribute names written as additional top-level
            keys (only for quantized weights, and only when not ``None``).
    """
    if not hasattr(module, 'weight'):
        logging.warning(f"Warning: state dict on uninitialized op {prefix}")
        return sd

    bias = getattr(module, 'bias', None)
    if bias is not None:
        sd[f"{prefix}bias"] = bias

    weight = module.weight
    if weight is None:
        return sd

    weight_key = f"{prefix}weight"
    if not isinstance(weight, QuantizedTensor):
        sd[weight_key] = weight
        return sd

    sd.update(weight.state_dict(weight_key))

    quant_conf = {"format": module.quant_format}
    if getattr(module, '_full_precision_mm_config', False):
        quant_conf["full_precision_matrix_mult"] = True
    if extra_quant_conf:
        quant_conf.update(extra_quant_conf)
    encoded_conf = json.dumps(quant_conf).encode("utf-8")
    sd[f"{prefix}comfy_quant"] = torch.tensor(list(encoded_conf), dtype=torch.uint8)

    for name in extra_quant_params:
        value = getattr(module, name, None)
        if value is not None:
            sd[f"{prefix}{name}"] = value
    return sd
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
_disabled_formats = disabled
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._orig_shape = (out_features, in_features)
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _load_from_state_dict(self, *args):
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
return _quantized_apply(self, fn, recurse)
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
    """Bank of ``num_experts`` expert weight matrices, addressed per expert.

    All experts are stored together on ``self.weight`` as one 3D tensor: either
    a plain compute-dtype Parameter or a Parameter wrapping a QuantizedTensor
    whose leading dimension is the expert index.

    The state-dict layout matches ``mixed_precision_ops.Linear`` with a leading
    expert dimension:

        {prefix}.weight           quant data (storage_t), leading dim = E
        {prefix}.weight_scale     block / per-tensor scale
        {prefix}.weight_scale_2   [E] or scalar, NVFP4 only
        {prefix}.bias             [E, out_features], optional, compute_dtype
        {prefix}.comfy_quant      json -> {"format": "...", "num_experts": E}

    Without ``comfy_quant`` the weight loads as a plain compute-dtype 3D
    Parameter of shape [E, out, in].
    """

    _disabled_formats = disabled

    def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        self._orig_shape = (num_experts, out_features, in_features)
        # The ``dtype`` argument is not consulted here; tensors are created in
        # the compute dtype configured on MixedPrecisionOps.
        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))

        # Filled in when the state dict is loaded.
        self.weight = None
        self.quant_format = None
        self.layout_type = None
        self._full_precision_mm = MixedPrecisionOps._full_precision_mm
        self._full_precision_mm_config = False
        # (weight, bias) pair cached while a bank_resident() context is active.
        self._resident_bank = None

    def reset_parameters(self):
        """No-op: nothing is initialised here."""
        return None

    def _apply(self, fn, recurse=True):
        """Hand ``fn`` to ``_quantized_apply`` for this module."""
        return _quantized_apply(self, fn, recurse)

    def _load_from_state_dict(self, *args):
        """Delegate loading to ``_load_quantized_module`` without extra quant params."""
        _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)

    def expert_weight(self, i: int):
        """Return expert ``i``'s weight (a Tensor, or a per-expert QuantizedTensor view)."""
        bank = self.weight
        if not isinstance(bank, QuantizedTensor):
            return bank[i]
        return self._expert_qt_from(bank, i)

    @contextlib.contextmanager
    def bank_resident(self, input):
        """Cast the whole bank once so ``expert_linear`` calls inside reuse it.

        Not re-entrant: do not nest calls on the same instance.
        """
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        self._resident_bank = (weight, bias)
        try:
            yield self
        finally:
            # Drop the cached pair before releasing the cast tensors.
            self._resident_bank = None
            uncast_bias_weight(self, weight, bias, offload_stream)

    def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
        """Apply expert ``i``'s linear layer (with its bias, if any) to ``input``."""
        cached = getattr(self, "_resident_bank", None)
        if cached is None:
            # No resident bank: cast for this single call and release afterwards.
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            try:
                return self._expert_linear_impl(input, weight, bias, i)
            finally:
                uncast_bias_weight(self, weight, bias, offload_stream)

        weight, bias = cached
        return self._expert_linear_impl(input, weight, bias, i)

    def _expert_linear_impl(self, input, weight, bias, i):
        """Run the linear op for expert ``i`` against an already-cast bank."""
        if isinstance(weight, QuantizedTensor):
            expert_w = self._expert_qt_from(weight, i)
        else:
            expert_w = weight[i]
        expert_b = None if bias is None else cast_to_input(bias[i], input, copy=False)

        if not isinstance(expert_w, QuantizedTensor):
            return torch.nn.functional.linear(input, expert_w, expert_b)

        fast_path = (
            not self._full_precision_mm
            and expert_w.layout_cls.supports_fast_matmul()
            and input.dim() == 2
        )
        if fast_path:
            quantized_input = QuantizedTensor.from_float(input, self.layout_type)
            return torch.nn.functional.linear(quantized_input, expert_w, expert_b)

        # Fallback: dequantize the expert weight and multiply in full precision.
        result = input @ expert_w.dequantize().t()
        if expert_b is None:
            return result
        return result + expert_b

    def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
        """Build a per-expert QuantizedTensor by indexing into a resident bank."""
        bank_params = weight._params
        # A 0-dim scale is shared by all experts; otherwise pick expert i's entry.
        expert_scale = bank_params.scale[i] if bank_params.scale.dim() else bank_params.scale
        fields = {
            "scale": expert_scale,
            "orig_dtype": bank_params.orig_dtype,
            "orig_shape": (self.out_features, self.in_features),
        }
        if hasattr(bank_params, "block_scale"):  # NVFP4
            fields["block_scale"] = bank_params.block_scale[i]
        expert_params = type(bank_params)(**fields)
        return QuantizedTensor(weight._qdata[i], weight._layout_cls, expert_params)

    def state_dict(self, *args, destination=None, prefix="", **kwargs):
        """Write this bank into ``destination`` (or a new dict), recording ``num_experts``."""
        if destination is not None:
            target = destination
        else:
            target = {}
        return _quantized_weight_state_dict(self, target, prefix, extra_quant_conf={"num_experts": self.num_experts})
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
quant_format = layer_conf.get("format") if layer_conf is not None else None
manually_loaded_keys = []
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
manually_loaded_keys.append(weight_key)
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
elif layer_conf is not None:
# Unsupported format — restore the marker so it round-trips; fall through to default load.
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix)
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight

View File

@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.model_patcher
import comfy.lora
@ -1284,6 +1285,7 @@ class CLIPType(Enum):
FLUX2 = 25
LONGCAT_IMAGE = 26
COGVIDEOX = 27
LENS = 28
# PIXELDIT must not reuse LENS's value: Enum turns a repeated value into an
# alias, so ``PIXELDIT = 28`` made CLIPType.PIXELDIT the same member as
# CLIPType.LENS and every ``== CLIPType.PIXELDIT`` check also matched Lens.
# NOTE(review): 29 assumes no later member already uses it - confirm.
PIXELDIT = 29
@ -1337,6 +1339,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
def detect_te_model(sd):
@ -1378,6 +1381,9 @@ def detect_te_model(sd):
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
@ -1564,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B:
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")

View File

@ -830,6 +830,48 @@ class Flux2(Flux):
return None
class Lens(supported_models_base.BASE):
    """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""

    unet_config = {"image_model": "lens"}

    sampling_settings = {
        "shift": 1.829,  # Default mu for 1440x1440 (and any seq_len > 4300).
    }

    unet_extra_config = {}
    latent_format = latent_formats.Flux2

    supported_inference_dtypes = [torch.bfloat16, torch.float32]  # fp16 causes NaNs

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.Lens`` wrapper using the FLUX model type."""
        flux_type = model_base.ModelType.FLUX
        return model_base.Lens(self, model_type=flux_type, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair for Lens.

        The encoder is looked for under ``<prefix>gpt_oss.transformer.`` first
        and directly under ``<prefix>`` second (recognised by its attention
        sink key). If found, the ``llama_detect`` result for that prefix is
        passed to ``lens_te``; otherwise ``lens_te`` is called with defaults.
        """
        gpt_oss = comfy.text_encoders.gpt_oss
        base_prefix = self.text_encoder_key_prefix[0]

        detected = {}
        for hint in ("gpt_oss.transformer.", ""):
            candidate = "{}{}".format(base_prefix, hint)
            if "{}layers.0.self_attn.sinks".format(candidate) in state_dict:
                detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, candidate)
                break

        return supported_models_base.ClipTarget(gpt_oss.LensTokenizer, gpt_oss.lens_te(**detected))
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@ -2181,6 +2223,7 @@ models = [
Omnigen2,
QwenImage,
Flux2,
Lens,
Kandinsky5Image,
Kandinsky5,
Anima,

View File

@ -0,0 +1,600 @@
"""GPT-OSS text encoder for Lens."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@dataclass
class GptOss20BConfig:
    """Hyper-parameters of the GPT-OSS-20B decoder used as the Lens text encoder.

    When ``layer_types`` is not supplied it is filled with an alternating
    pattern that starts with ``"sliding_attention"`` at layer 0.
    """

    # Embedding / width
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    num_hidden_layers: int = 24
    # Attention geometry
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    head_dim: int = 64
    # Mixture of experts routing
    num_local_experts: int = 32
    num_experts_per_tok: int = 4
    sliding_window: int = 128
    # Rotary embedding (YaRN) settings
    original_max_position_embeddings: int = 4096
    rope_theta: float = 150000.0
    rope_factor: float = 32.0
    rope_beta_fast: float = 32.0
    rope_beta_slow: float = 1.0
    rope_truncate: bool = False
    rms_norm_eps: float = 1e-5
    attention_bias: bool = True
    layer_types: Optional[List[str]] = None
    # Gated activation constants used by the expert MLPs
    moe_alpha: float = 1.702
    moe_limit: float = 7.0

    def __post_init__(self):
        if self.layer_types is not None:
            return
        # Even layer indices use the sliding window, odd ones attend to everything.
        pattern = []
        for layer_idx in range(self.num_hidden_layers):
            pattern.append("full_attention" if layer_idx % 2 else "sliding_attention")
        self.layer_types = pattern
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim
def find_correction_dim(num_rotations: float) -> float:
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def find_correction_range() -> tuple[float, float]:
low = find_correction_dim(beta_fast)
high = find_correction_dim(beta_slow)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
if min_ == max_:
max_ += 0.001
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
return torch.clamp(linear, 0, 1)
def get_mscale(scale: float) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
attention_scaling = get_mscale(factor)
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range()
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
return inv_freq, attention_scaling
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
pos_e = position_ids[:, None, :].float()
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
sin_split = sin.shape[-1] // 2
return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
                          attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
    """Run attention with a learned per-head sink logit.

    A sink enlarges each row's softmax denominator without contributing to the
    output. This is emulated by appending one all-zero key/value position and
    writing the sink logit into the additive mask at that extra column.
    """
    if num_kv_groups > 1 and not TORCH_HAS_GQA:
        # No native grouped-query support: replicate the k/v heads.
        k = k.repeat_interleave(num_kv_groups, dim=1)
        v = v.repeat_interleave(num_kv_groups, dim=1)

    batch, _, q_len, head_dim = q.shape
    kv_heads = k.shape[1]
    kv_len = k.shape[-2]

    sink_slot = (batch, kv_heads, 1, head_dim)
    k = torch.cat([k, k.new_zeros(sink_slot)], dim=-2)
    v = torch.cat([v, v.new_zeros(sink_slot)], dim=-2)

    if attention_mask is None:
        token_bias = q.new_zeros(batch, num_heads, q_len, kv_len)
    else:
        token_bias = attention_mask[..., :kv_len].expand(batch, num_heads, q_len, kv_len)
    sink_bias = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
    mask = torch.cat([token_bias, sink_bias], dim=-1)

    attend = optimized_attention_for_device(q.device, mask=True, small_input=True)
    return attend(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
class GptOssAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and per-head sinks."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        # Only sliding-attention layers carry a window size.
        if self.layer_type == "sliding_attention":
            self.sliding_window = config.sliding_window
        else:
            self.sliding_window = None

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        linear_kwargs = {"bias": config.attention_bias, "device": device, "dtype": dtype}
        self.q_proj = ops.Linear(config.hidden_size, q_width, **linear_kwargs)
        self.k_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.v_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.o_proj = ops.Linear(q_width, config.hidden_size, **linear_kwargs)
        # One learned sink logit per query head.
        self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
        """Project to q/k/v, apply rope, attend with sinks and project back."""
        batch, seq_len, _ = hidden_states.shape

        def split_heads(projected: torch.Tensor, heads: int) -> torch.Tensor:
            # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
            return projected.view(batch, seq_len, heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(hidden_states), self.num_heads)
        k = split_heads(self.k_proj(hidden_states), self.num_kv_heads)
        v = split_heads(self.v_proj(hidden_states), self.num_kv_heads)
        q, k = apply_rope(q, k, freqs_cis)
        attended = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
        return self.o_proj(attended)
# Mixture of Experts
class GptOssTopKRouter(nn.Module):
    """Linear router that picks the top-k experts for every token."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        expert_count = config.num_local_experts
        self.weight = nn.Parameter(torch.empty(expert_count, config.hidden_size, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.empty(expert_count, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scores, expert_indices)`` for the ``top_k`` experts of each row."""
        router_w = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
        router_b = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
        router_logits = F.linear(hidden_states, router_w, router_b)
        best_logits, best_experts = torch.topk(router_logits, self.top_k, dim=-1)
        # Normalise over the selected experts only, not over all experts.
        weights = F.softmax(best_logits, dim=-1, dtype=best_logits.dtype)
        return weights, best_experts
class GptOssExperts(nn.Module):
    """Expert MLP banks (gate/up and down projections) with a clamped gated activation."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.alpha = config.moe_alpha
        self.limit = config.moe_limit

        shared = {"num_experts": self.num_experts, "bias": True, "device": device, "dtype": dtype}
        # gate and up rows are interleaved in one projection, hence 2 * intermediate.
        self.gate_up_proj = ops.MoEExperts(
            in_features=self.hidden_size, out_features=2 * self.intermediate_size, **shared
        )
        self.down_proj = ops.MoEExperts(
            in_features=self.intermediate_size, out_features=self.hidden_size, **shared
        )

    def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
        """Split interleaved gate/up channels and return ``(up + 1) * gate * sigmoid(alpha * gate)``."""
        gate = gate_up[..., 0::2].clamp(max=self.limit)
        up = gate_up[..., 1::2].clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        # glu + up * glu
        return torch.addcmul(glu, up, glu)

    def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Run each selected expert on its tokens and sum the weighted results per token.

        ``hidden_states`` is indexed as [tokens, hidden]; ``router_indices`` and
        ``routing_weights`` are indexed as [tokens, top_k].
        """
        num_tokens = hidden_states.shape[0]
        top_k = router_indices.shape[-1]
        width = hidden_states.shape[-1]

        # One output row per (token, top-k slot) pair, summed over slots at the end.
        slot_outputs = torch.zeros((num_tokens * top_k, width), dtype=hidden_states.dtype, device=hidden_states.device)
        # [experts, top_k, tokens] membership mask.
        assignment = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
        used_experts = torch.greater(assignment.sum(dim=(-1, -2)), 0).nonzero()

        with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
                self.down_proj.bank_resident(hidden_states) as down_bank:
            for expert_idx in used_experts.flatten().tolist():
                slot_idx, token_idx = torch.where(assignment[expert_idx])
                expert_in = hidden_states[token_idx]
                activated = self._apply_gate(gate_up_bank.expert_linear(expert_in, expert_idx))
                expert_out = down_bank.expert_linear(activated, expert_idx)
                scaled = expert_out * routing_weights[token_idx, slot_idx, None]
                slot_outputs[token_idx * top_k + slot_idx] = scaled.to(slot_outputs.dtype)

        return slot_outputs.view(num_tokens, top_k, width).sum(dim=1)
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward: a router followed by the expert banks."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
        self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token of a [batch, seq, hidden] input through its experts."""
        batch, seq_len, width = hidden_states.shape
        tokens = hidden_states.reshape(-1, width)
        routing_weights, expert_ids = self.router(tokens)
        mixed = self.experts(tokens, expert_ids, routing_weights)
        return mixed.reshape(batch, seq_len, width)
# Decoder layer + model
class GptOssDecoderLayer(nn.Module):
    """Pre-norm decoder block: attention then MoE MLP, each with a residual add."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        norm_kwargs = {"eps": config.rms_norm_eps, "device": device, "dtype": dtype}
        self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
        self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.layer_type = config.layer_types[layer_idx]

    def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
        """Apply the block, using the mask registered under this layer's ``layer_type``."""
        layer_mask = attention_masks[self.layer_type]
        attn_out = self.self_attn(self.input_layernorm(x), layer_mask, freqs_cis)
        x = x + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
i = torch.arange(S, device=device).view(-1, 1)
j = torch.arange(S, device=device).view(1, -1)
keep = (j <= i) & (j > i - window)
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
class GptOssModel(nn.Module):
    """GPT-OSS decoder with multi-layer hidden-state capture + early exit."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.config = config
        self.dtype = dtype

        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(GptOssDecoderLayer(config, layer_idx, device=device, dtype=dtype, ops=ops))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

        # Always build on CPU so the buffer survives meta-device construction.
        inv_freq, attn_scaling = _yarn_inv_freq(
            head_dim=config.head_dim,
            base=config.rope_theta,
            factor=config.rope_factor,
            beta_fast=config.rope_beta_fast,
            beta_slow=config.rope_beta_slow,
            original_max_position_embeddings=config.original_max_position_embeddings,
            truncate=config.rope_truncate,
            device=torch.device("cpu"),
        )
        self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
        self.rope_attention_scaling = float(attn_scaling)

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    def get_input_embeddings(self):
        return self.embed_tokens

    def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
                               ) -> dict[str, torch.Tensor]:
        """Build one additive mask per attention flavour present in the config."""
        masks = {"full_attention": _make_full_causal_mask(B, S, attention_mask, dtype, device)}
        if "sliding_attention" in self.config.layer_types:
            masks["sliding_attention"] = _make_sliding_causal_mask(
                B, S, self.config.sliding_window, attention_mask, dtype, device
            )
        return masks

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
                capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
        """Run the decoder.

        With ``capture_layers`` the raw outputs of those layers are returned
        under ``"hidden_states"`` (in the order requested) and the loop stops
        after the highest requested layer. Without it, the final normalised
        state is returned under ``"last_hidden_state"``.
        """
        batch, seq_len = input_ids.shape
        device = input_ids.device
        dtype = self.dtype

        hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, -1)
        freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
        attn_masks = self._build_attention_masks(batch, seq_len, attention_mask, dtype, device)

        requested = list(capture_layers) if capture_layers else []
        if requested:
            last_needed = max(requested)
            slot_of = {layer_idx: slot for slot, layer_idx in enumerate(requested)}
            captured: Optional[List[Optional[torch.Tensor]]] = [None] * len(requested)
        else:
            last_needed = self.config.num_hidden_layers - 1
            slot_of = {}
            captured = None

        for depth, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attn_masks, freqs_cis)
            slot = slot_of.get(depth)
            if slot is not None:
                captured[slot] = hidden_states
            if depth >= last_needed:
                # Nothing deeper is needed; skip the remaining layers.
                break

        if captured is None:
            return {"last_hidden_state": self.norm(hidden_states)}
        return {"hidden_states": captured}
# Lens chat-template constants (verbatim from the reference pipeline).
_LENS_CHAT_SYSTEM = (
"Describe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background."
)
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
LENS_TXT_OFFSET = 97
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
LENS_MAX_TOKENS = 512
# The reference GPT-OSS Harmony template injects today's date here
_LENS_CHAT_DATE = "2026-05-23"
def _lens_render_chat(prompt: str) -> str:
"""Render the Lens prompt in GPT-OSS Harmony format."""
return (
f"<|start|>system<|message|>"
f"You are ChatGPT, a large language model trained by OpenAI.\n"
f"Knowledge cutoff: 2024-06\n"
f"Current date: {_LENS_CHAT_DATE}\n\n"
f"Reasoning: medium\n\n"
f"# Valid channels: analysis, commentary, final. "
f"Channel must be included for every message.<|end|>"
f"<|start|>developer<|message|># Instructions\n\n"
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
f"<|start|>user<|message|>{prompt}<|end|>"
f"<|start|>assistant<|channel|>analysis<|message|>"
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>"
)
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
class _GptOssRawTokenizer:
"""Raw ``tokenizers.Tokenizer`` wrapper.
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
(``tokenizer_json`` key) rather than as a committed file. Extracted
it in ``sd.py`` and passes it here via ``tokenizer_data``.
"""
def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
if tokenizer_json_bytes is None:
raise ValueError(
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
"embeds the tokenizer."
)
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@classmethod
def from_pretrained(cls, tokenizer_data, **kwargs):
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
def __call__(self, text):
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
def get_vocab(self):
return self.tokenizer.get_vocab()
def convert_tokens_to_ids(self, tokens):
return [self.tokenizer.token_to_id(t) for t in tokens]
def decode(self, ids, **kwargs):
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer front-end that renders prompts through the Lens chat template."""

    tokenizer_json_data = None

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # Keep the serialized tokenizer so state_dict() can hand it back.
        self.tokenizer_json_data = tokenizer_data.get("tokenizer_json", None)
        super().__init__(
            self.tokenizer_json_data,
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=2880,
            embedding_key="gpt_oss",
            tokenizer_class=_GptOssRawTokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_left=False,
            disable_weights=True,
            tokenizer_data=tokenizer_data,
        )
        self.pad_token_id = _LENS_PAD_TOKEN_ID

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize the chat-rendered prompt, keeping at most ``LENS_MAX_TOKENS`` ids.

        Every token gets weight 1.0. A blank prompt yields ``[[]]``.
        """
        # Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
        is_blank = not text or not text.strip()
        if is_blank:
            return [[]]

        token_ids = self.tokenizer(_lens_render_chat(text))["input_ids"]
        token_ids = token_ids[:LENS_MAX_TOKENS]
        return [[(int(token_id), 1.0) for token_id in token_ids]]

    def state_dict(self):
        """Return the serialized tokenizer under ``tokenizer_json`` when one was supplied."""
        if self.tokenizer_json_data is None:
            return {}
        return {"tokenizer_json": self.tokenizer_json_data}
class LensTokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper that registers the Lens tokenizer under the ``gpt_oss`` name."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        wrapper_kwargs = {
            "embedding_directory": embedding_directory,
            "tokenizer_data": tokenizer_data,
            "name": "gpt_oss",
            "tokenizer": LensGptOssTokenizer,
        }
        super().__init__(**wrapper_kwargs)
class LensGptOssClipModel(nn.Module):
    """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""

    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
        super().__init__()
        options = dict(model_options or {})

        ops_cls = options.get("custom_operations")
        if ops_cls is None:
            # No ops supplied: build mixed-precision ops from the quantization metadata.
            quant_config = options.get("quantization_metadata") or {}
            ops_cls = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
        self.operations = ops_cls

        self.config = GptOss20BConfig(**options.get("gpt_oss_config", {}))
        self.selected_layers = tuple(options.get("selected_layers", LENS_SELECTED_LAYERS))
        self.txt_offset = int(options.get("txt_offset", LENS_TXT_OFFSET))

        self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=ops_cls)
        self.num_layers = self.config.num_hidden_layers
        self.dtype = dtype
        self.execution_device = None
        self._pad_token_id = _LENS_PAD_TOKEN_ID

    def set_clip_options(self, options):
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        self.execution_device = None

    def _gather_tokens(self, token_weight_pairs):
        """Right-pad the token ids of each batch entry and build the matching 0/1 mask."""
        rows = [[int(pair[0]) for pair in batch] for batch in token_weight_pairs]
        width = max(len(row) for row in rows)
        device = self.execution_device

        ids = torch.full((len(rows), width), self._pad_token_id, dtype=torch.long, device=device)
        mask = torch.zeros((len(rows), width), dtype=torch.long, device=device)
        for row_idx, row in enumerate(rows):
            length = len(row)
            ids[row_idx, :length] = torch.tensor(row, dtype=torch.long, device=device)
            mask[row_idx, :length] = 1
        return ids, mask

    def encode_token_weights(self, token_weight_pairs):
        """Encode tokens into the stacked hidden states of the selected layers.

        Returns ``(features, None, extra)``. The first ``txt_offset`` positions
        are dropped and the selected layers are concatenated along the feature
        axis. ``extra`` holds the trimmed attention mask and the layer count.
        """
        layer_count = len(self.selected_layers)

        # Empty negative: emit zero-length features + zero mask
        if not any(len(batch) for batch in token_weight_pairs):
            device = self.execution_device
            batch_size = len(token_weight_pairs)
            feature_dim = layer_count * self.config.hidden_size
            empty_features = torch.zeros(batch_size, 0, feature_dim, dtype=self.dtype, device=device)
            empty_mask = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
            return empty_features, None, {"attention_mask": empty_mask, "num_layers_stacked": layer_count}

        input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
        encoded = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
        # list of L x [B, S, H] -> [B, S, L, H]
        stacked = torch.stack(encoded["hidden_states"], dim=2)

        offset = self.txt_offset
        if stacked.shape[1] > offset:
            stacked = stacked[:, offset:].contiguous()
            trimmed_mask = attn_mask[:, offset:]
        else:
            # Sequence is no longer than the offset: nothing is left after trimming.
            stacked = stacked[:, :0]
            trimmed_mask = attn_mask[:, :0]

        batch_size, seq_len, stacked_layers, hidden = stacked.shape
        features = stacked.reshape(batch_size, seq_len, stacked_layers * hidden)
        return features, None, {"attention_mask": trimmed_mask, "num_layers_stacked": stacked_layers}

    def load_sd(self, sd):
        return self.transformer.load_state_dict(sd, strict=False, assign=True)
class LensTEModel(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper that uses ``LensGptOssClipModel`` under the name ``"gpt_oss"``."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        # A missing (or empty) options mapping is replaced by a fresh dict.
        if not model_options:
            model_options = {}
        super().__init__(
            device=device,
            dtype=dtype,
            name="gpt_oss",
            clip_model=LensGptOssClipModel,
            model_options=model_options,
        )
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
    """Build a ``LensTEModel`` subclass with loader defaults baked in.

    Args:
        dtype_llama: dtype used when the constructor is called with
            ``dtype=None``.
        llama_quantization_metadata: when not ``None``, stored under
            ``"quantization_metadata"`` in the model options.

    Returns:
        The subclass itself (not an instance).
    """

    class LensTEModel_(LensTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            # Copy so the caller's mapping is never mutated.
            options = dict(model_options) if model_options else {}
            if llama_quantization_metadata is not None:
                options["quantization_metadata"] = llama_quantization_metadata
            # An explicit dtype wins; otherwise fall back to the factory default.
            if dtype is None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=options)

    return LensTEModel_

View File

@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
"pre_cfg",
default=False,
optional=True,
tooltip=(
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
"without clamping (can amplify). Matches the norm-scaled CFG used by "
"models like Lens. Default false keeps the original post-CFG x0-space "
"attenuate-only behavior."
),
),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
# NOTE(review): this span is a rendered diff hunk with old and new lines
# interleaved (two `execute` signatures, the `cfg_norm` body shown twice).
# Comments below describe the statements as written; confirm against the file.
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
# The patch is registered on a clone, which is what gets returned.
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
# pre_cfg: register a sampler CFG function that forms the CFG combination
# itself and rescales it to the norm of the conditional prediction.
if pre_cfg:
def cfg_norm_pre(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
# Linear CFG combination of the two predictions.
comb = uncond + cond_scale * (cond - uncond)
# Norms are taken over dim=1 (presumably the channel axis - TODO confirm)
# and kept as a size-1 dim so they broadcast against `comb`.
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
# Factor that gives `comb` the same dim-1 norm as `cond`; where comb_norm
# is 0 the factor is 1. There is no upper clamp, so it can exceed 1.
rescale = torch.where(
comb_norm > 0,
cond_norm / comb_norm.clamp_min(1e-12),
torch.ones_like(comb_norm),
)
rescaled = comb * rescale
# strength blends back toward standard linear CFG (1.0 = full rescale).
if strength != 1.0:
rescaled = strength * rescaled + (1.0 - strength) * comb
return rescaled
m.set_model_sampler_cfg_function(cfg_norm_pre)
# Otherwise: register a post-CFG function that only attenuates.
else:
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
# Ratio of the two norms clamped to [0, 1]; 1e-8 guards the division.
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)

View File

@ -969,7 +969,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "pixeldit"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -979,7 +979,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
# Recipe text: one "<type>: <text encoder>" entry per line.
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)