diff --git a/comfy/ldm/lens/model.py b/comfy/ldm/lens/model.py new file mode 100644 index 000000000..7bff7f6af --- /dev/null +++ b/comfy/ldm/lens/model.py @@ -0,0 +1,513 @@ +"""Lens denoising transformer (DiT)""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.flux.layers +import comfy.patcher_extension +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention + + +def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor: + return comfy.ldm.flux.layers.timestep_embedding(t, dim) + + +def _lens_position_ids( + frame: int, height: int, width: int, text_seq_len: int, + scale_rope: bool = True, device=None, +) -> torch.Tensor: + """Lens axial (frame, h, w) position ids for joint image + text sequence. + + With ``scale_rope=True`` h/w are centered around 0 (negative + positive + halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``; + caller adds a batch dim for ``EmbedND``. + """ + if scale_rope: + h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device), + torch.arange(0, height // 2, device=device)]) + w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device), + torch.arange(0, width // 2, device=device)]) + text_start = max(height // 2, width // 2) + else: + h_pos = torch.arange(height, device=device) + w_pos = torch.arange(width, device=device) + text_start = max(height, width) + + f_pos = torch.arange(frame, device=device) + img_ids = torch.zeros(frame, height, width, 3, device=device) + img_ids[..., 0] = f_pos[:, None, None] + img_ids[..., 1] = h_pos[None, :, None] + img_ids[..., 2] = w_pos[None, None, :] + img_ids = img_ids.reshape(-1, 3) + + # Text positions replicate across all 3 axes (matches original packing). 
+ txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float() + txt_ids = txt_pos[:, None].expand(text_seq_len, 3) + + return torch.cat([img_ids, txt_ids], dim=0) + + +class _TimestepEmbedder(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = F.silu(x) + return self.linear_2(x) + + +class LensTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations) + + def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + proj = _lens_time_proj(timestep, 256) + return self.timestep_embedder(proj.to(dtype=hidden_states.dtype)) + + +class GateMLP(nn.Module): + """SwiGLU MLP.""" + + def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x))) + + +class LensJointAttention(nn.Module): + """Joint image+text attention with fused QKV per stream.""" + + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + dim_head: int = 64, + heads: int = 8, + out_dim: Optional[int] = None, + eps: float = 1e-5, + dtype=None, + device=None, + operations=None, + ) 
-> None: + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = self.inner_dim // dim_head + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + + self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + + # ModuleList([Linear, Identity]) for state-dict key compatibility. + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device), + nn.Identity(), + ]) + self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + bsz, seq_img, _ = hidden_states.shape + seq_txt = encoder_hidden_states.shape[1] + + # image stream + img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head) + img_q, img_k, img_v = img_qkv.unbind(dim=2) + img_q = self.norm_q(img_q) + img_k = self.norm_k(img_k) + img_v = img_v.contiguous() + del img_qkv + + # text stream + txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head) + txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2) + txt_q = self.norm_added_q(txt_q) + txt_k = self.norm_added_k(txt_k) + txt_v = txt_v.contiguous() + del txt_qkv + + 
class LensTransformerBlock(nn.Module):
    """One dual-stream Lens block.

    The image and text streams each own a modulation layer, two norms and a
    SwiGLU MLP; the two streams are mixed through one ``LensJointAttention``
    call. ``forward`` returns ``(encoder_hidden_states, hidden_states)`` —
    text first, image second.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        rms_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.attn = LensJointAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            eps=1e-5,  # fixed for the attention q/k norms, independent of ``eps``
            operations=operations,
            **factory,
        )

        def build_norm():
            # RMSNorm by default; otherwise a LayerNorm without learnable affine.
            if rms_norm:
                return operations.RMSNorm(dim, eps=eps, **factory)
            return operations.LayerNorm(dim, eps=eps, elementwise_affine=False, **factory)

        def build_modulation():
            # Sequential(SiLU, Linear) so state-dict lands at *_mod.1.{weight,bias}.
            return nn.Sequential(
                nn.SiLU(),
                operations.Linear(dim, 6 * dim, bias=True, **factory),
            )

        ffn_dim = int(dim / 3 * 8)

        # Attribute order is kept so module registration order is unchanged.
        self.img_mod = build_modulation()
        self.img_norm1 = build_norm()
        self.img_norm2 = build_norm()
        self.img_mlp = GateMLP(dim, ffn_dim, operations=operations, **factory)

        self.txt_mod = build_modulation()
        self.txt_norm1 = build_norm()
        self.txt_norm2 = build_norm()
        self.txt_mlp = GateMLP(dim, ffn_dim, operations=operations, **factory)

    @staticmethod
    def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``mod_params`` into (shift, scale, gate) and apply shift/scale to ``x``.

        Each chunk gets a new axis at dim 1 so it broadcasts over ``x``.
        Returns ``(x * (1 + scale) + shift, gate)``.
        """
        shift, scale, gate = (
            part.unsqueeze(1) for part in torch.chunk(mod_params, 3, dim=-1)
        )
        return x * (1 + scale) + shift, gate

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run gated joint attention, then the gated per-stream MLPs.

        ``hidden_states`` is the image stream and ``encoder_hidden_states`` the
        text stream; ``temb`` drives both modulation layers.
        """
        # Each modulation output holds two (shift, scale, gate) triples:
        # the first for the attention sub-layer, the second for the MLP.
        img_attn_mod, img_mlp_mod = self.img_mod(temb).chunk(2, dim=-1)
        txt_attn_mod, txt_mlp_mod = self.txt_mod(temb).chunk(2, dim=-1)

        img_attn_in, img_attn_gate = self._modulate(self.img_norm1(hidden_states), img_attn_mod)
        txt_attn_in, txt_attn_gate = self._modulate(self.txt_norm1(encoder_hidden_states), txt_attn_mod)

        img_attn_out, txt_attn_out = self.attn(
            hidden_states=img_attn_in,
            encoder_hidden_states=txt_attn_in,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            transformer_options=transformer_options,
        )

        hidden_states = hidden_states + img_attn_gate * img_attn_out
        encoder_hidden_states = encoder_hidden_states + txt_attn_gate * txt_attn_out

        img_mlp_in, img_mlp_gate = self._modulate(self.img_norm2(hidden_states), img_mlp_mod)
        hidden_states = hidden_states + img_mlp_gate * self.img_mlp(img_mlp_in)

        txt_mlp_in, txt_mlp_gate = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mlp_mod)
        encoder_hidden_states = encoder_hidden_states + txt_mlp_gate * self.txt_mlp(txt_mlp_in)

        return encoder_hidden_states, hidden_states
+ dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels if out_channels is not None else in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.multi_layer_encoder_feature = multi_layer_encoder_feature + self.selected_layer_index = list(selected_layer_index) + self.dtype = dtype + + self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + self.time_text_embed = LensTimestepProjEmbeddings( + embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations + ) + + if self.multi_layer_encoder_feature: + self.txt_norm = nn.ModuleList( + [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + for _ in self.selected_layer_index] + ) + self.txt_in = operations.Linear( + enc_hidden_dim * len(self.selected_layer_index), + self.inner_dim, bias=True, dtype=dtype, device=device, + ) + else: + self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + LensTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + eps=1e-6, + rms_norm=rms_norm, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_out = _AdaLayerNormContinuousNoAffine( + self.inner_dim, self.inner_dim, eps=1e-6, + dtype=dtype, device=device, operations=operations, + ) + self.proj_out = operations.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, + dtype=dtype, device=device, + ) + + def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, 
def _forward(
    self,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    transformer_options: Optional[Dict[str, Any]] = None,
    control: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> torch.Tensor:
    """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.

    Flattens ``x`` into ``h * w`` image tokens, normalises and projects the
    text features, runs every transformer block, then applies the output
    norm and projection and restores a ``[B, channels, h, w]`` layout.

    Args:
        x: Latent input; unpacked as ``(B, C, h, w)``.
        timestep: Per-sample timestep, cast to the hidden-state dtype before
            embedding.
        context: Text features. With ``multi_layer_encoder_feature`` the last
            dim is split evenly into ``len(selected_layer_index)`` layers.
        attention_mask: Optional text mask; an all-True ``[B, text_len]`` mask
            is used when omitted.
        transformer_options: Optional hook dictionary. Reads ``"patches"``
            (``"post_input"``, ``"double_block"``) and
            ``"patches_replace"["dit"]`` keyed by ``("double_block", i)``.
            A shallow copy is annotated with ``total_blocks``, ``block_type``
            and ``block_index``.
        control: Optional dict; ``control["input"][i]`` (when present and not
            ``None``) is added in place to the leading image tokens after
            block ``i``.

    Returns:
        Contiguous tensor of shape ``[B, proj_out_features, h, w]``.
    """
    if transformer_options is None:
        transformer_options = {}
    # Shallow copy: the bookkeeping keys written below must not leak to the caller.
    transformer_options = transformer_options.copy()
    patches = transformer_options.get("patches", {})
    patches_replace = transformer_options.get("patches_replace", {})
    blocks_replace = patches_replace.get("dit", {})

    B, C, h, w = x.shape
    # [B, C, h, w] -> [B, h*w, C]: one token per spatial location.
    hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)

    if self.multi_layer_encoder_feature:
        L = len(self.selected_layer_index)
        enc_dim = context.shape[-1] // L
        # Split the concatenated per-layer features back into L tensors.
        encoder_hidden_states = list(
            context.reshape(B, -1, L, enc_dim).unbind(dim=2)
        )
        text_seq_len = encoder_hidden_states[0].shape[1]
    else:
        encoder_hidden_states = context
        text_seq_len = context.shape[1]

    if attention_mask is None:
        attention_mask = torch.ones(
            (B, text_seq_len), dtype=torch.bool, device=x.device
        )

    img_len = h * w
    joint_mask = self._build_joint_attention_mask(attention_mask, img_len)

    hidden_states = self.img_in(hidden_states)
    timestep = timestep.to(hidden_states.dtype)

    if self.multi_layer_encoder_feature:
        # Each selected encoder layer has its own norm; results are re-concatenated.
        normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
        encoder_hidden_states = torch.cat(normed, dim=-1)
    else:
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
    encoder_hidden_states = self.txt_in(encoder_hidden_states)

    if "post_input" in patches:
        for p in patches["post_input"]:
            out = p({
                "img": hidden_states,
                "txt": encoder_hidden_states,
                "transformer_options": transformer_options,
            })
            hidden_states = out["img"]
            encoder_hidden_states = out["txt"]

    temb = self.time_text_embed(timestep, hidden_states)
    ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
    freqs_cis = self.pos_embed(ids)

    transformer_options["total_blocks"] = len(self.transformer_blocks)
    transformer_options["block_type"] = "double"
    for i, block in enumerate(self.transformer_blocks):
        transformer_options["block_index"] = i
        if ("double_block", i) in blocks_replace:
            # ``block=block`` binds this iteration's block, so the callback
            # stays correct even if a patch holds on to it past this iteration.
            def block_wrap(args, block=block):
                out = {}
                out["txt"], out["img"] = block(
                    hidden_states=args["img"],
                    encoder_hidden_states=args["txt"],
                    temb=args["vec"],
                    freqs_cis=args["pe"],
                    attention_mask=args.get("attn_mask"),
                    transformer_options=args.get("transformer_options"),
                )
                return out
            out = blocks_replace[("double_block", i)](
                {
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "vec": temb,
                    "pe": freqs_cis,
                    "attn_mask": joint_mask,
                    "transformer_options": transformer_options,
                },
                {"original_block": block_wrap},
            )
            encoder_hidden_states = out["txt"]
            hidden_states = out["img"]
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                freqs_cis=freqs_cis,
                attention_mask=joint_mask,
                transformer_options=transformer_options,
            )

        if "double_block" in patches:
            for p in patches["double_block"]:
                out = p({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "x": x,
                    "block_index": i,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out["txt"]

        if control is not None:
            control_i = control.get("input")
            if control_i is not None and i < len(control_i):
                add = control_i[i]
                if add is not None:
                    # In-place residual on the first add.shape[1] image tokens.
                    hidden_states[:, :add.shape[1]] += add

    hidden_states = self.norm_out(hidden_states, temb)
    out = self.proj_out(hidden_states)
    # proj_out emits patch_size**2 * out_channels features per token, which is
    # not guaranteed to equal the input channel count C. Infer the channel dim
    # rather than assuming C, so such configs no longer fail in reshape.
    return out.reshape(B, h, w, -1).permute(0, 3, 1, 2).contiguous()
def extra_conds(self, **kwargs):
    """Extend the base conditioning with Lens text features and text mask.

    ``cross_attn`` is stored as ``c_crossattn`` and ``attention_mask`` under
    its own name, both wrapped in ``CONDRegular``; absent or ``None`` inputs
    leave the base-class entries untouched.
    """
    out = super().extra_conds(**kwargs)
    # (kwarg name, conditioning key), processed in the original order.
    for source_key, cond_key in (
        ("cross_attn", "c_crossattn"),
        ("attention_mask", "attention_mask"),
    ):
        value = kwargs.get(source_key, None)
        if value is not None:
            out[cond_key] = comfy.conds.CONDRegular(value)
    return out
+ selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.'))) + else: + enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0] + selected_layer_index = (0,) + + return { + "image_model": "lens", + "in_channels": img_in_w.shape[1], + "out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default) + "num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'), + "num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default + "enc_hidden_dim": enc_hidden_dim, + "multi_layer_encoder_feature": multi_layer, + "selected_layer_index": selected_layer_index, + } + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" diff --git a/comfy/ops.py b/comfy/ops.py index 9bcd6c900..56445be8d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,6 +18,7 @@ import torch import logging +import contextlib import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +# Quantized-weight module helpers + +def _quantized_apply(module, fn, recurse=True): + """Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights.""" + if recurse: + for child in module.children(): + child._apply(fn) + for key, param in module._parameters.items(): + if param is None: + continue + p = fn(param) + if (not torch.is_inference_mode_enabled()) and p.is_inference(): + p = p.clone() + module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) + for key, buf in module._buffers.items(): + if buf is not None: + module._buffers[key] = fn(buf) + return module + + +def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, 
load_extra_params=False): + """Shared _load_from_state_dict body for quantized-weight modules. + + Pops weight (+ scales, +/- extras), populates module.weight as a Parameter + or Parameter-wrapped QuantizedTensor, then calls super_load and strips + consumed keys from missing_keys. Reads compute_dtype from factory_kwargs + and disabled formats from module._disabled_formats. + """ + device = module.factory_kwargs["device"] + compute_dtype = module.factory_kwargs["dtype"] + disabled_formats = module._disabled_formats + layer_name = prefix.rstrip('.') + + weight = state_dict.pop(f"{prefix}weight", None) + if weight is None: + logging.warning(f"Missing weight for layer {layer_name}") + module.weight = None + return + manually_loaded_keys = [f"{prefix}weight"] + + def pop_scale(name, dtype=None): + key = f"{prefix}{name}" + v = state_dict.pop(key, None) + if v is not None: + v = v.to(device=device) + if dtype is not None: + v = v.view(dtype=dtype) + manually_loaded_keys.append(key) + return v + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False) + else: + module.quant_format = layer_conf.get("format", None) + module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) + if not module._full_precision_mm: + module._full_precision_mm = module._full_precision_mm_config + if module.quant_format in disabled_formats: + module._full_precision_mm = True + if module.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[module.quant_format] + module.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(module.layout_type) + + # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8. 
+ if module.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scales = {"scale": pop_scale("weight_scale")} + elif module.quant_format == "mxfp8": + bs = pop_scale("weight_scale", torch.float8_e8m0fnu) + if bs is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + scales = {"scale": bs} + elif module.quant_format == "nvfp4": + ts = pop_scale("weight_scale_2") + bs = pop_scale("weight_scale", torch.float8_e4m3fn) + if ts is None or bs is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + scales = {"scale": ts, "block_scale": bs} + else: + raise ValueError(f"Unsupported quantization format: {module.quant_format}") + + params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape) + module.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params), + requires_grad=False, + ) + + if load_extra_params: + for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + +def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()): + """Shared state_dict body. 
extra_quant_conf merges into the comfy_quant JSON; + extra_quant_params names attributes written as additional top-level keys.""" + if not hasattr(module, 'weight'): + logging.warning(f"Warning: state dict on uninitialized op {prefix}") + return sd + bias = getattr(module, 'bias', None) + if bias is not None: + sd[f"{prefix}bias"] = bias + if module.weight is None: + return sd + if isinstance(module.weight, QuantizedTensor): + sd.update(module.weight.state_dict(f"{prefix}weight")) + quant_conf = {"format": module.quant_format} + if getattr(module, '_full_precision_mm_config', False): + quant_conf["full_precision_matrix_mult"] = True + if extra_quant_conf: + quant_conf.update(extra_quant_conf) + sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) + for name in extra_quant_params: + value = getattr(module, name, None) + if value is not None: + sd[f"{prefix}{name}"] = value + else: + sd[f"{prefix}weight"] = module.weight + return sd + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): @@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: + _disabled_formats = disabled + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): super().__init__() self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features + self._orig_shape = (out_features, in_features) if bias: self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) else: @@ -1083,151 +1217,12 @@ def 
mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None - def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): - key = f"{prefix}{param_name}" - value = state_dict.pop(key, None) - if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) - manually_loaded_keys.append(key) - return value - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): - - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - logging.warning(f"Missing weight for layer {layer_name}") - self.weight = None - return - - manually_loaded_keys = [weight_key] - - layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: - layer_conf = json.loads(layer_conf.numpy().tobytes()) - - if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - self.quant_format = layer_conf.get("format", None) - self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) - if not self._full_precision_mm: - self._full_precision_mm = self._full_precision_mm_config - - if self.quant_format in MixedPrecisionOps._disabled: - self._full_precision_mm = True - - if self.quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") - - qconfig = QUANT_ALGOS[self.quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] - layout_cls = get_layout_class(self.layout_type) - - # Load format-specific parameters - if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: - # FP8: single tensor scale - scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - - 
params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "mxfp8": - # MXFP8: E8M0 block scales stored as uint8 in safetensors - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.uint8) - - if block_scale is None: - raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") - - block_scale = block_scale.view(torch.float8_e8m0fnu) - - params = layout_cls.Params( - scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "nvfp4": - # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) - tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.float8_e4m3fn) - - if tensor_scale is None or block_scale is None: - raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") - - params = layout_cls.Params( - scale=tensor_scale, - block_scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - else: - raise ValueError(f"Unsupported quantization format: {self.quant_format}") - - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), - requires_grad=False - ) - - for param_name in qconfig["parameters"]: - if param_name in {"weight_scale", "weight_scale_2"}: - continue # Already handled above - - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) - - 
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight'): - logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) - return sd - - if self.bias is not None: - sd["{}bias".format(prefix)] = self.bias - - if self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - - input_scale = getattr(self, 'input_scale', None) - if input_scale is not None: - sd["{}input_scale".format(prefix)] = input_scale - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",)) def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight, requires_grad=False) def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working - if recurse: - for module in self.children(): - module._apply(fn) + return _quantized_apply(self, fn, recurse) - for key, param in self._parameters.items(): - if 
param is None: - continue - p = fn(param) - if (not torch.is_inference_mode_enabled()) and p.is_inference(): - p = p.clone() - self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) - for key, buf in self._buffers.items(): - if buf is not None: - self._buffers[key] = fn(buf) - return self + class MoEExperts(torch.nn.Module, CastWeightBiasOp): + """Container for E quantized expert weights, indexed via expert_weight(i). + + The bank lives on self.weight as a single 3D tensor — either a + compute_dtype Parameter or a Parameter wrapping a QuantizedTensor + with leading expert dim. + + State-dict layout matches mixed_precision_ops.Linear with a leading + expert dim: + {prefix}.weight quant data (storage_t), leading dim = E + {prefix}.weight_scale block / per-tensor scale + {prefix}.weight_scale_2 [E] or scalar NVFP4 only + {prefix}.bias [E, out_features] optional, compute_dtype + {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}} + + Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in]. 
+ """ + + _disabled_formats = disabled + + def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + super().__init__() + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self._orig_shape = (num_experts, out_features, in_features) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + if bias: + self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + # Populated by _load_from_state_dict: + self.weight = None + self.quant_format = None + self.layout_type = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False + self._resident_bank = None + + def reset_parameters(self): + return None + + def _apply(self, fn, recurse=True): + return _quantized_apply(self, fn, recurse) + + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False) + + def expert_weight(self, i: int): + """Expert i's weight (Tensor or per-expert QuantizedTensor view).""" + if isinstance(self.weight, QuantizedTensor): + return self._expert_qt_from(self.weight, i) + return self.weight[i] + + @contextlib.contextmanager + def bank_resident(self, input): + """Cast the whole bank once; expert_linear inside reuses the cast. + Not re-entrant — do not nest calls on the same instance. 
+ """ + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + self._resident_bank = (weight, bias) + try: + yield self + finally: + self._resident_bank = None + uncast_bias_weight(self, weight, bias, offload_stream) + + def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: + """Linear against expert i's weight (with optional bias).""" + resident = getattr(self, "_resident_bank", None) + if resident is not None: + weight, bias = resident + return self._expert_linear_impl(input, weight, bias, i) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + try: + return self._expert_linear_impl(input, weight, bias, i) + finally: + uncast_bias_weight(self, weight, bias, offload_stream) + + def _expert_linear_impl(self, input, weight, bias, i): + if isinstance(weight, QuantizedTensor): + qw = self._expert_qt_from(weight, i) + else: + qw = weight[i] + b = cast_to_input(bias[i], input, copy=False) if bias is not None else None + + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, b) + out = input @ qw.dequantize().t() + return out + b if b is not None else out + return torch.nn.functional.linear(input, qw, b) + + def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor: + """Build a per-expert QuantizedTensor by indexing into a resident bank.""" + params = weight._params + kwargs = { + "scale": params.scale[i] if params.scale.dim() else params.scale, + "orig_dtype": params.orig_dtype, + "orig_shape": (self.out_features, self.in_features), + } + if hasattr(params, "block_scale"): # NVFP4 + kwargs["block_scale"] = params.block_scale[i] + return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs)) + + def state_dict(self, *args, destination=None, 
prefix="", **kwargs): + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts}) class Embedding(manual_cast.Embedding): - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): weight_key = f"{prefix}weight" layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) if layer_conf is not None: @@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Only fp8 makes sense for embeddings (per-row dequant via index select). # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. - quant_format = layer_conf.get("format", None) if layer_conf is not None else None - if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + quant_format = layer_conf.get("format") if layer_conf is not None else None + manually_loaded_keys = [] + + if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict: self.quant_format = quant_format qconfig = QUANT_ALGOS[quant_format] self.layout_type = qconfig["comfy_tensor_layout"] layout_cls = get_layout_class(self.layout_type) weight = state_dict.pop(weight_key) - manually_loaded_keys = [weight_key] + manually_loaded_keys.append(weight_key) scale_key = f"{prefix}weight_scale" scale = state_dict.pop(scale_key, None) @@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter( QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), requires_grad=False) + elif layer_conf is not None: + # Unsupported format — restore the marker so it round-trips; fall through to default load. 
+ state_dict[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - for k in manually_loaded_keys: - if k in missing_keys: - missing_keys.remove(k) - else: - if layer_conf is not None: - state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight') or self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix) def forward_comfy_cast_weights(self, input, out_dtype=None): weight = self.weight diff --git a/comfy/sd.py b/comfy/sd.py index a4e49763a..beb782310 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -68,6 +68,7 @@ import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 +import comfy.text_encoders.gpt_oss import comfy.model_patcher import comfy.lora @@ -1283,6 +1284,7 @@ class CLIPType(Enum): FLUX2 = 25 LONGCAT_IMAGE = 26 COGVIDEOX = 
27 + LENS = 28 + + @@ -1335,6 +1337,7 @@ class TEModel(Enum): GEMMA_4_E2B = 30 GEMMA_4_31B = 31 T5_GEMMA = 32 + GPT_OSS_20B = 33 def detect_te_model(sd): @@ -1376,6 +1379,9 @@ def detect_te_model(sd): else: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B + # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512). + if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd: + return TEModel.GPT_OSS_20B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] if weight.shape[0] == 256: @@ -1558,6 +1564,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.GPT_OSS_20B: + clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.QWEN3_4B: if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 617db4f28..e451892e9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -829,6 +829,48 @@ class Flux2(Flux): return None + +class Lens(supported_models_base.BASE): + """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE).""" + + unet_config = { + "image_model": "lens", + } + + sampling_settings = { + "shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300) + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + +
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device) + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + for hint in ("gpt_oss.transformer.", ""): + full_prefix = "{}{}".format(pref, hint) + if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict: + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(**detect), + ) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(), + ) + + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -2096,6 +2138,7 @@ models = [ Omnigen2, QwenImage, Flux2, + Lens, Kandinsky5Image, Kandinsky5, Anima, diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py new file mode 100644 index 000000000..d596ef9a0 --- /dev/null +++ b/comfy/text_encoders/gpt_oss.py @@ -0,0 +1,600 @@ +"""GPT-OSS text encoder for Lens.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy import sd1_clip +from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device +from comfy.text_encoders.llama import RMSNorm, apply_rope + + +@dataclass +class GptOss20BConfig: + vocab_size: int = 201088 + hidden_size: int = 2880 + intermediate_size: int = 2880 + num_hidden_layers: int = 24 + num_attention_heads: int = 64 + 
num_key_value_heads: int = 8 + head_dim: int = 64 + num_local_experts: int = 32 + num_experts_per_tok: int = 4 + sliding_window: int = 128 + original_max_position_embeddings: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32.0 + rope_beta_fast: float = 32.0 + rope_beta_slow: float = 1.0 + rope_truncate: bool = False + rms_norm_eps: float = 1e-5 + attention_bias: bool = True + layer_types: Optional[List[str]] = None + moe_alpha: float = 1.702 + moe_limit: float = 7.0 + + def __post_init__(self): + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 2 else "full_attention" + for i in range(self.num_hidden_layers) + ] + + +def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float, + original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]: + """YARN inv_freq + attention scaling (matches transformers).""" + dim = head_dim + + def find_correction_dim(num_rotations: float) -> float: + return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def find_correction_range() -> tuple[float, float]: + low = find_correction_dim(beta_fast) + high = find_correction_dim(beta_slow) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor: + if min_ == max_: + max_ += 0.001 + linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_) + return torch.clamp(linear, 0, 1) + + def get_mscale(scale: float) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + attention_scaling = get_mscale(factor) + + pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = 
find_correction_range() + extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2) + inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor + return inv_freq, attention_scaling + + +def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + pos_e = position_ids[:, None, :].float() + freqs = (inv_freq_e @ pos_e).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1) + sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1) + sin_split = sin.shape[-1] // 2 + return cos, sin[..., :sin_split], -sin[..., sin_split:] + + +def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor, + attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor: + """Attention with per-head sinks. + + Sinks add a learned term to each row's softmax denominator but contribute + nothing to the output. We fake this by appending one zero k/v position and + putting the sink logit in the mask at that column. 
+ """ + + if num_kv_groups > 1 and not TORCH_HAS_GQA: + k = k.repeat_interleave(num_kv_groups, dim=1) + v = v.repeat_interleave(num_kv_groups, dim=1) + + B, _, S_q, D = q.shape + H_kv = k.shape[1] + S_kv = k.shape[-2] + + k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2) + v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2) + + sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1) + if attention_mask is not None: + mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv) + else: + mask_left = q.new_zeros(B, num_heads, S_q, S_kv) + mask = torch.cat([mask_left, sinks_col], dim=-1) + + op = optimized_attention_for_device(q.device, mask=True, small_input=True) + return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True) + + +class GptOssAttention(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + bias = config.attention_bias + self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype) + self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype)) + + def forward(self, hidden_states: 
torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor: + B, S, _ = hidden_states.shape + + q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q, k = apply_rope(q, k, freqs_cis) + + out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups) + return self.o_proj(out) + + +# Mixture of Experts + +class GptOssTopKRouter(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False) + bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False) + logits = F.linear(hidden_states, weight, bias) + top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1) + # Softmax over top-k slice only + scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype) + return scores, top_idx + + +class GptOssExperts(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.alpha = config.moe_alpha + self.limit = config.moe_limit + + E = self.num_experts + H = self.hidden_size + I = self.intermediate_size + + self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, 
bias=True, device=device, dtype=dtype) + self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype) + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate = gate_up[..., ::2] + up = gate_up[..., 1::2] + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + return torch.addcmul(glu, up, glu) + + def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + N = hidden_states.shape[0] + top_k = router_indices.shape[-1] + H = hidden_states.shape[-1] + + per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device) + + expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \ + self.down_proj.bank_resident(hidden_states) as down_bank: + for ei in expert_hit: + expert_idx = int(ei.item()) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current = hidden_states[token_idx] + + gate_up = gate_up_bank.expert_linear(current, expert_idx) + gated = self._apply_gate(gate_up) + expert_out = down_bank.expert_linear(gated, expert_idx) + + weighted = expert_out * routing_weights[token_idx, top_k_pos, None] + + flat_idx = token_idx * top_k + top_k_pos + per_pair[flat_idx] = weighted.to(per_pair.dtype) + + return per_pair.view(N, top_k, H).sum(dim=1) + + +class GptOssMLP(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.router = GptOssTopKRouter(config, device=device, dtype=dtype) + self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + flat = 
hidden_states.reshape(-1, H) + scores, idx = self.router(flat) + out = self.experts(flat, idx, scores) + return out.reshape(B, S, H) + + +# Decoder layer + model + +class GptOssDecoderLayer(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops) + self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.layer_type = config.layer_types[layer_idx] + + def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor: + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x + return x + + +def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + i = torch.arange(S, device=device).view(-1, 1) + j = torch.arange(S, device=device).view(1, -1) + keep = (j <= i) & (j > i - window) + mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device)) + mask = 
mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +class GptOssModel(nn.Module): + """GPT-OSS decoder with multi-layer hidden-state capture + early exit.""" + + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.config = config + self.dtype = dtype + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) + self.layers = nn.ModuleList( + [ + GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + # Always build on CPU so the buffer survives meta-device construction. + inv_freq, attn_scaling = _yarn_inv_freq( + head_dim=config.head_dim, + base=config.rope_theta, + factor=config.rope_factor, + beta_fast=config.rope_beta_fast, + beta_slow=config.rope_beta_slow, + original_max_position_embeddings=config.original_max_position_embeddings, + truncate=config.rope_truncate, + device=torch.device("cpu"), + ) + self.register_buffer("rope_inv_freq", inv_freq, persistent=False) + self.rope_attention_scaling = float(attn_scaling) + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.embed_tokens + + def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device, + ) -> dict[str, torch.Tensor]: + full = _make_full_causal_mask(B, S, attention_mask, dtype, device) + masks = {"full_attention": full} + if any(t == "sliding_attention" for t in self.config.layer_types): + masks["sliding_attention"] = _make_sliding_causal_mask( + B, S, self.config.sliding_window, attention_mask, dtype, device + ) + return masks + + def 
forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, + capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]: + B, S = input_ids.shape + device = input_ids.device + dtype = self.dtype + + hidden_states = self.embed_tokens(input_ids, out_dtype=dtype) + + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype) + + attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device) + + capture_layers = list(capture_layers) if capture_layers else None + if capture_layers: + max_layer = max(capture_layers) + wanted = {idx: pos for pos, idx in enumerate(capture_layers)} + captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers) + else: + max_layer = self.config.num_hidden_layers - 1 + wanted = None + captured = None + + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, attn_masks, freqs_cis) + if wanted is not None and i in wanted: + captured[wanted[i]] = hidden_states + if i >= max_layer: + break + + if captured is not None: + return {"hidden_states": captured} + return {"last_hidden_state": self.norm(hidden_states)} + + +# Lens chat-template constants (verbatim from the reference pipeline). +_LENS_CHAT_SYSTEM = ( + "Describe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background." +) +_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description." 
+LENS_TXT_OFFSET = 97 +LENS_SELECTED_LAYERS = (5, 11, 17, 23) +LENS_MAX_TOKENS = 512 + + +# The reference GPT-OSS Harmony template injects today's date here +_LENS_CHAT_DATE = "2026-05-23" + + +def _lens_render_chat(prompt: str) -> str: + """Render the Lens prompt in GPT-OSS Harmony format.""" + return ( + f"<|start|>system<|message|>" + f"You are ChatGPT, a large language model trained by OpenAI.\n" + f"Knowledge cutoff: 2024-06\n" + f"Current date: {_LENS_CHAT_DATE}\n\n" + f"Reasoning: medium\n\n" + f"# Valid channels: analysis, commentary, final. " + f"Channel must be included for every message.<|end|>" + f"<|start|>developer<|message|># Instructions\n\n" + f"{_LENS_CHAT_SYSTEM}\n\n<|end|>" + f"<|start|>user<|message|>{prompt}<|end|>" + f"<|start|>assistant<|channel|>analysis<|message|>" + f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>" + f"<|start|>assistant<|channel|>final<|message|>" + ) + + +# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table). +_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|> + + +class _GptOssRawTokenizer: + """Raw ``tokenizers.Tokenizer`` wrapper. + + The tokenizer JSON ships as a byte tensor inside the encoder checkpoint + (``tokenizer_json`` key) rather than as a committed file. ``sd.py`` extracts + it and passes it here via ``tokenizer_data``. + """ + + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + if tokenizer_json_bytes is None: + raise ValueError( + "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the " + "encoder state dict. Re-bundle the encoder via bundle_te.py so it " + "embeds the tokenizer."
+ ) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +class LensGptOssTokenizer(sd1_clip.SDTokenizer): + tokenizer_json_data = None + + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__( + tokenizer_json, + embedding_directory=embedding_directory, + pad_with_end=False, + embedding_size=2880, + embedding_key="gpt_oss", + tokenizer_class=_GptOssRawTokenizer, + has_start_token=False, + has_end_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=1, + pad_left=False, + disable_weights=True, + tokenizer_data=tokenizer_data, + ) + self.pad_token_id = _LENS_PAD_TOKEN_ID + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + # Empty prompt -> empty list; encode_token_weights returns zeros (uncond). 
+ if not text or not text.strip(): + return [[]] + rendered = _lens_render_chat(text) + ids = self.tokenizer(rendered)["input_ids"] + if len(ids) > LENS_MAX_TOKENS: + ids = ids[:LENS_MAX_TOKENS] + return [[(int(t), 1.0) for t in ids]] + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + +class LensTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="gpt_oss", + tokenizer=LensGptOssTokenizer, + ) + + +class LensGptOssClipModel(nn.Module): + """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).""" + + def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs): + super().__init__() + model_options = dict(model_options or {}) + + operations = model_options.get("custom_operations") + if operations is None: + quant_config = model_options.get("quantization_metadata") or {} + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + self.operations = operations + + cfg_overrides = model_options.get("gpt_oss_config", {}) + self.config = GptOss20BConfig(**cfg_overrides) + self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS)) + self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET)) + + self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + self.dtype = dtype + self.execution_device = None + self._pad_token_id = _LENS_PAD_TOKEN_ID + + def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) + + def reset_clip_options(self): + self.execution_device = None + + def _gather_tokens(self, token_weight_pairs): + ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs] + 
pad_id = self._pad_token_id + max_len = max(len(x) for x in ids_list) + device = self.execution_device + ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device) + mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device) + for i, x in enumerate(ids_list): + ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device) + mask[i, : len(x)] = 1 + return ids, mask + + def encode_token_weights(self, token_weight_pairs): + # Empty negative: emit zero-length features + zero mask + if all(len(batch) == 0 for batch in token_weight_pairs): + device = self.execution_device + B = len(token_weight_pairs) + L = len(self.selected_layers) + H = self.config.hidden_size + flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device) + mask = torch.zeros(B, 0, dtype=torch.long, device=device) + return flat, None, {"attention_mask": mask, "num_layers_stacked": L} + + input_ids, attn_mask = self._gather_tokens(token_weight_pairs) + out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers) + layers = out["hidden_states"] # list of L × [B, S, H] + stacked = torch.stack(layers, dim=2) # [B, S, L, H] + + offset = self.txt_offset + if stacked.shape[1] > offset: + stacked = stacked[:, offset:].contiguous() + mask_trim = attn_mask[:, offset:] + else: + stacked = stacked[:, :0] + mask_trim = attn_mask[:, :0] + + B, S, L, H = stacked.shape + flat = stacked.reshape(B, S, L * H) + extra = {"attention_mask": mask_trim, "num_layers_stacked": L} + return flat, None, extra + + def load_sd(self, sd): + return self.transformer.load_state_dict(sd, strict=False, assign=True) + + +class LensTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {}) + + +def lens_te(dtype_llama=None, llama_quantization_metadata=None): + class 
LensTEModel_(LensTEModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + mo = dict(model_options or {}) + if llama_quantization_metadata is not None: + mo["quantization_metadata"] = llama_quantization_metadata + if dtype is None and dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=mo) + + return LensTEModel_ diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 4ebb4b51e..b585c560f 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + "pre_cfg", + default=False, + optional=True, + tooltip=( + "If true, rescale the combined noise BEFORE the sampler's CFG combine, " + "without clamping (can amplify). Matches the norm-scaled CFG used by " + "models like Lens. Default false keeps the original post-CFG x0-space " + "attenuate-only behavior." + ), + ), ], outputs=[io.Model.Output(display_name="patched_model")], is_experimental=True, ) @classmethod - def execute(cls, model, strength) -> io.NodeOutput: + def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput: m = model.clone() - def cfg_norm(args): - cond_p = args['cond_denoised'] - pred_text_ = args["denoised"] + if pre_cfg: + def cfg_norm_pre(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + comb = uncond + cond_scale * (cond - uncond) + cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True) + comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True) + rescale = torch.where( + comb_norm > 0, + cond_norm / comb_norm.clamp_min(1e-12), + torch.ones_like(comb_norm), + ) + rescaled = comb * rescale + # strength blends back toward standard linear CFG (1.0 = full rescale). 
+ if strength != 1.0: + rescaled = strength * rescaled + (1.0 - strength) * comb + return rescaled + m.set_model_sampler_cfg_function(cfg_norm_pre) + else: + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] - norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) - norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) - scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) - return pred_text_ * scale * strength + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength - m.set_model_sampler_post_cfg_function(cfg_norm) + m.set_model_sampler_post_cfg_function(cfg_norm) return io.NodeOutput(m) diff --git a/nodes.py b/nodes.py index 669a7057b..13d3864cd 100644 --- a/nodes.py +++ b/nodes.py @@ -969,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -979,7 +979,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: 
llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)