mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
601 lines
25 KiB
Python
601 lines
25 KiB
Python
"""GPT-OSS text encoder for Lens."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import math
|
||
from dataclasses import dataclass
|
||
from typing import Any, List, Optional, Sequence
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
import comfy.ops
|
||
from comfy import sd1_clip
|
||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||
|
||
|
||
@dataclass
class GptOss20BConfig:
    """Hyper-parameters of the GPT-OSS-20B decoder used as the Lens text encoder.

    ``layer_types`` gives the attention kind of every decoder layer. When left
    as ``None`` it is filled with the alternating pattern
    ``sliding_attention, full_attention, sliding_attention, ...`` (even layer
    indices slide, odd indices attend to the full prefix). A caller-supplied
    list is copied, and must contain at least ``num_hidden_layers`` entries,
    each ``"sliding_attention"`` or ``"full_attention"``.

    Raises:
        ValueError: if ``layer_types`` is too short or names an unknown kind.
    """

    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    num_hidden_layers: int = 24
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    head_dim: int = 64
    num_local_experts: int = 32
    num_experts_per_tok: int = 4
    sliding_window: int = 128
    # YaRN rope-scaling parameters (passed to ``_yarn_inv_freq`` by the model).
    original_max_position_embeddings: int = 4096
    rope_theta: float = 150000.0
    rope_factor: float = 32.0
    rope_beta_fast: float = 32.0
    rope_beta_slow: float = 1.0
    rope_truncate: bool = False
    rms_norm_eps: float = 1e-5
    attention_bias: bool = True
    layer_types: Optional[List[str]] = None
    # Clamped gated-activation constants used by the expert MLPs.
    moe_alpha: float = 1.702
    moe_limit: float = 7.0

    def __post_init__(self):
        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if (i + 1) % 2 else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        else:
            # Copy so that later mutation of the caller's list cannot alter this config.
            self.layer_types = list(self.layer_types)

        # Fail fast here instead of with an IndexError/KeyError while building or running layers.
        if len(self.layer_types) < self.num_hidden_layers:
            raise ValueError(
                f"layer_types has {len(self.layer_types)} entries but "
                f"num_hidden_layers is {self.num_hidden_layers}"
            )
        known = {"sliding_attention", "full_attention"}
        unknown = sorted(set(self.layer_types[: self.num_hidden_layers]) - known)
        if unknown:
            raise ValueError(f"unknown layer_types {unknown}; expected one of {sorted(known)}")
||
|
||
|
||
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
|
||
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
|
||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||
dim = head_dim
|
||
|
||
def find_correction_dim(num_rotations: float) -> float:
|
||
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||
2 * math.log(base)
|
||
)
|
||
|
||
def find_correction_range() -> tuple[float, float]:
|
||
low = find_correction_dim(beta_fast)
|
||
high = find_correction_dim(beta_slow)
|
||
if truncate:
|
||
low = math.floor(low)
|
||
high = math.ceil(high)
|
||
return max(low, 0), min(high, dim - 1)
|
||
|
||
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
|
||
if min_ == max_:
|
||
max_ += 0.001
|
||
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
|
||
return torch.clamp(linear, 0, 1)
|
||
|
||
def get_mscale(scale: float) -> float:
|
||
if scale <= 1:
|
||
return 1.0
|
||
return 0.1 * math.log(scale) + 1.0
|
||
|
||
attention_scaling = get_mscale(factor)
|
||
|
||
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||
|
||
low, high = find_correction_range()
|
||
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
|
||
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
|
||
return inv_freq, attention_scaling
|
||
|
||
|
||
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
|
||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||
pos_e = position_ids[:, None, :].float()
|
||
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
|
||
emb = torch.cat((freqs, freqs), dim=-1)
|
||
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
|
||
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
|
||
sin_split = sin.shape[-1] // 2
|
||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||
|
||
|
||
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
                          attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
    """Scaled-dot-product attention with one learned sink logit per head.

    A sink adds an extra term to every row's softmax denominator without
    contributing to the output. To get that effect, one all-zero key/value slot
    is appended. Its raw logit is ``q . 0 = 0``, so the additive mask value in
    that column (the sink) becomes the full logit, and its value vector is zero.

    ``q``/``k``/``v`` are indexed as ``(batch, heads, seq, head_dim)``.
    ``attention_mask``, if given, is an additive mask that broadcasts to
    ``(batch, num_heads, q_len, kv_len)``. The output layout is whatever the
    optimized attention op returns with ``skip_reshape=True`` (presumably
    ``(batch, q_len, num_heads * head_dim)``; TODO confirm).
    """
    if num_kv_groups > 1 and not TORCH_HAS_GQA:
        # No native GQA support: materialize the shared kv heads for every query head.
        k, v = (t.repeat_interleave(num_kv_groups, dim=1) for t in (k, v))

    batch, _, q_len, head_dim = q.shape
    kv_heads = k.shape[1]
    kv_len = k.shape[-2]

    # Append the zero "sink" slot to keys and values.
    k = torch.cat([k, k.new_zeros(batch, kv_heads, 1, head_dim)], dim=-2)
    v = torch.cat([v, v.new_zeros(batch, kv_heads, 1, head_dim)], dim=-2)

    sink_bias = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
    if attention_mask is None:
        token_bias = q.new_zeros(batch, num_heads, q_len, kv_len)
    else:
        token_bias = attention_mask[..., :kv_len].expand(batch, num_heads, q_len, kv_len)
    full_mask = torch.cat([token_bias, sink_bias], dim=-1)

    attend = optimized_attention_for_device(q.device, mask=True, small_input=True)
    return attend(q, k, v, num_heads, mask=full_mask, skip_reshape=True, enable_gqa=True)
|
||
|
||
|
||
class GptOssAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and learned per-head sinks."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        # Recorded for reference only: forward() does not read it; any windowing
        # must already be encoded in the attention_mask it receives.
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        linear_kwargs = dict(bias=config.attention_bias, device=device, dtype=dtype)
        self.q_proj = ops.Linear(config.hidden_size, q_width, **linear_kwargs)
        self.k_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.v_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.o_proj = ops.Linear(q_width, config.hidden_size, **linear_kwargs)
        # One sink logit per query head (values come from the checkpoint).
        self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
        """Attend over ``hidden_states`` (``(batch, seq, hidden)``) and project back to ``hidden``."""
        batch, seq_len, _ = hidden_states.shape

        def split_heads(proj, n_heads):
            # (B, S, n_heads * head_dim) -> (B, n_heads, S, head_dim)
            return proj(hidden_states).view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj, self.num_heads)
        k = split_heads(self.k_proj, self.num_kv_heads)
        v = split_heads(self.v_proj, self.num_kv_heads)
        q, k = apply_rope(q, k, freqs_cis)

        attended = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
        return self.o_proj(attended)
|
||
|
||
|
||
# Mixture of Experts
|
||
|
||
class GptOssTopKRouter(nn.Module):
    """Linear router that picks ``num_experts_per_tok`` experts for each token."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        n_experts, width = config.num_local_experts, config.hidden_size
        self.weight = nn.Parameter(torch.empty(n_experts, width, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.empty(n_experts, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scores, indices)`` for the ``top_k`` highest router logits per row.

        The softmax runs over the selected logits only, so each row's scores sum to 1.
        """
        logits = F.linear(
            hidden_states,
            comfy.ops.cast_to_input(self.weight, hidden_states, copy=False),
            comfy.ops.cast_to_input(self.bias, hidden_states, copy=False),
        )
        top_logits, top_indices = logits.topk(self.top_k, dim=-1)
        return F.softmax(top_logits, dim=-1, dtype=top_logits.dtype), top_indices
|
||
|
||
|
||
class GptOssExperts(nn.Module):
    """Bank of ``num_local_experts`` gated MLP experts with a fused gate/up projection."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.alpha = config.moe_alpha
        self.limit = config.moe_limit

        bank = dict(num_experts=self.num_experts, bias=True, device=device, dtype=dtype)
        # Gate and up activations are interleaved in one projection (see _apply_gate).
        self.gate_up_proj = ops.MoEExperts(in_features=self.hidden_size, out_features=2 * self.intermediate_size, **bank)
        self.down_proj = ops.MoEExperts(in_features=self.intermediate_size, out_features=self.hidden_size, **bank)

    def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
        """Clamped gated activation: ``(up + 1) * gate * sigmoid(alpha * gate)``.

        Even channels of ``gate_up`` are the gate and odd channels the up branch.
        The gate is clamped from above at ``limit``; up is clamped to ``[-limit, limit]``.
        """
        limit = self.limit
        gate = gate_up[..., 0::2].clamp(max=limit)
        up = gate_up[..., 1::2].clamp(min=-limit, max=limit)
        swish = gate * torch.sigmoid(gate * self.alpha)
        # addcmul(a, b, c) = a + b * c, i.e. swish + up * swish == (up + 1) * swish.
        return torch.addcmul(swish, up, swish)

    def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Combine the routed experts' outputs for every token.

        ``hidden_states`` is indexed as ``(tokens, hidden)``. ``router_indices`` and
        ``routing_weights`` are ``(tokens, top_k)``. Returns ``(tokens, hidden)``.
        """
        num_tokens = hidden_states.shape[0]
        k = router_indices.shape[-1]
        hidden = hidden_states.shape[-1]

        # One row per (token, choice) pair; the rows are reduced by a single sum at the end.
        slots = torch.zeros((num_tokens * k, hidden), dtype=hidden_states.dtype, device=hidden_states.device)

        # membership[e, c, t] == 1 when token t picked expert e as its c-th choice.
        membership = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
        active = torch.greater(membership.sum(dim=(-1, -2)), 0).nonzero()

        with self.gate_up_proj.bank_resident(hidden_states) as up_bank:
            with self.down_proj.bank_resident(hidden_states) as down_bank:
                for expert_idx in (int(row.item()) for row in active):
                    choice, token = torch.where(membership[expert_idx])
                    activated = self._apply_gate(up_bank.expert_linear(hidden_states[token], expert_idx))
                    produced = down_bank.expert_linear(activated, expert_idx)
                    weighted = produced * routing_weights[token, choice].unsqueeze(-1)
                    slots[token * k + choice] = weighted.to(slots.dtype)

        return slots.view(num_tokens, k, hidden).sum(dim=1)
|
||
|
||
|
||
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward: a top-k router followed by the expert bank."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
        self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each position of a ``(batch, seq, hidden)`` input and return the same shape."""
        batch, seq, width = hidden_states.shape
        tokens = hidden_states.reshape(-1, width)
        weights, chosen = self.router(tokens)
        mixed = self.experts(tokens, chosen, weights)
        return mixed.reshape(batch, seq, width)
|
||
|
||
|
||
# Decoder layer + model
|
||
|
||
class GptOssDecoderLayer(nn.Module):
    """Pre-norm decoder block: ``x + attn(norm(x))`` then ``x + moe(norm(x))``."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
        self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        # Selects which entry of the per-kind mask dict this layer attends with.
        self.layer_type = config.layer_types[layer_idx]

    def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
        """Apply attention and MoE sub-blocks, each wrapped in a residual connection."""
        attn_out = self.self_attn(self.input_layernorm(x), attention_masks[self.layer_type], freqs_cis)
        x = x + attn_out
        return x + self.mlp(self.post_attention_layernorm(x))
|
||
|
||
|
||
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||
neg = torch.finfo(dtype).min
|
||
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
|
||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||
if key_padding_mask is not None:
|
||
kp = key_padding_mask.to(dtype=dtype)
|
||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||
mask = mask + kp
|
||
return mask
|
||
|
||
|
||
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||
neg = torch.finfo(dtype).min
|
||
i = torch.arange(S, device=device).view(-1, 1)
|
||
j = torch.arange(S, device=device).view(1, -1)
|
||
keep = (j <= i) & (j > i - window)
|
||
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
|
||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||
if key_padding_mask is not None:
|
||
kp = key_padding_mask.to(dtype=dtype)
|
||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||
mask = mask + kp
|
||
return mask
|
||
|
||
|
||
class GptOssModel(nn.Module):
    """GPT-OSS decoder with multi-layer hidden-state capture + early exit.

    Without ``capture_layers`` the whole stack runs and the final-normed output
    is returned under ``"last_hidden_state"``. With ``capture_layers`` the raw
    (pre-final-norm) output of each requested layer is returned in request
    order under ``"hidden_states"``, and layers deeper than the deepest
    requested one are skipped.
    """

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [
                GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

        # Always build on CPU so the buffer survives meta-device construction.
        inv_freq, attn_scaling = _yarn_inv_freq(
            head_dim=config.head_dim,
            base=config.rope_theta,
            factor=config.rope_factor,
            beta_fast=config.rope_beta_fast,
            beta_slow=config.rope_beta_slow,
            original_max_position_embeddings=config.original_max_position_embeddings,
            truncate=config.rope_truncate,
            device=torch.device("cpu"),
        )
        # persistent=False keeps the table out of state_dict(); it is always derived from the config.
        self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
        self.rope_attention_scaling = float(attn_scaling)

    @property
    def num_layers(self) -> int:
        """Number of decoder layers in the stack."""
        return self.config.num_hidden_layers

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.embed_tokens

    def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype,
                               device) -> dict[str, torch.Tensor]:
        """Build one additive mask per attention kind, keyed by layer type.

        ``"full_attention"`` is always built. ``"sliding_attention"`` is built only
        when at least one layer uses it.
        """
        full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
        masks = {"full_attention": full}
        if any(t == "sliding_attention" for t in self.config.layer_types):
            masks["sliding_attention"] = _make_sliding_causal_mask(
                B, S, self.config.sliding_window, attention_mask, dtype, device
            )
        return masks

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
                capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
        """Run the decoder.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: optional ``(batch, seq)`` key mask; 0 marks padding.
            capture_layers: 0-based decoder-layer indices whose outputs to return.
                Duplicates are allowed; every requested position is filled.

        Returns:
            ``{"hidden_states": [...]}``, one tensor per entry of
            ``capture_layers``, taken before the final norm, when
            ``capture_layers`` is non-empty. Otherwise
            ``{"last_hidden_state": norm(last layer output)}``.

        Raises:
            ValueError: if an entry of ``capture_layers`` is outside ``[0, num_layers)``.
        """
        num_layers = len(self.layers)
        requested = [int(idx) for idx in capture_layers] if capture_layers else []
        out_of_range = [idx for idx in requested if not 0 <= idx < num_layers]
        if out_of_range:
            # Previously these were silently never captured (or truncated the run),
            # which surfaced later as a confusing failure on a None entry.
            raise ValueError(
                f"capture_layers {out_of_range} out of range for a {num_layers}-layer model"
            )

        B, S = input_ids.shape
        device = input_ids.device

        hidden_states = self.embed_tokens(input_ids, out_dtype=self.dtype)
        # A model built with dtype=None has no dtype for torch.finfo()/.to();
        # follow the embedding output in that case.
        dtype = self.dtype if self.dtype is not None else hidden_states.dtype

        position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
        freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)

        attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)

        # layer index -> every output slot that requested it (handles duplicates).
        slots: dict[int, list[int]] = {}
        for pos, idx in enumerate(requested):
            slots.setdefault(idx, []).append(pos)
        captured: Optional[List[Optional[torch.Tensor]]] = [None] * len(requested) if requested else None
        last_needed = max(requested) if requested else num_layers - 1

        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attn_masks, freqs_cis)
            for pos in slots.get(i, ()):
                captured[pos] = hidden_states
            if i >= last_needed:
                # Early exit: deeper layers cannot influence the requested outputs.
                break

        if captured is not None:
            return {"hidden_states": captured}
        return {"last_hidden_state": self.norm(hidden_states)}
|
||
|
||
|
||
# Lens chat-template constants (verbatim from the reference pipeline).
|
||
_LENS_CHAT_SYSTEM = (
|
||
"Describe the image by detailing the color, shape, size, texture, "
|
||
"quantity, text, spatial relationships of the objects and background."
|
||
)
|
||
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
|
||
LENS_TXT_OFFSET = 97
|
||
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
|
||
LENS_MAX_TOKENS = 512
|
||
|
||
|
||
# The reference GPT-OSS Harmony template injects today's date here
|
||
_LENS_CHAT_DATE = "2026-05-23"
|
||
|
||
|
||
def _lens_render_chat(prompt: str) -> str:
    """Wrap ``prompt`` in the GPT-OSS Harmony chat layout used by Lens.

    The result contains, in order:

    - a system turn (identity, knowledge cutoff, pinned date, reasoning level, valid channels);
    - a developer turn carrying the Lens instructions;
    - the user turn with ``prompt``;
    - a completed assistant ``analysis`` turn;
    - an open assistant ``final`` turn with no message body.
    """
    system_turn = (
        "<|start|>system<|message|>"
        "You are ChatGPT, a large language model trained by OpenAI.\n"
        "Knowledge cutoff: 2024-06\n"
        f"Current date: {_LENS_CHAT_DATE}\n\n"
        "Reasoning: medium\n\n"
        "# Valid channels: analysis, commentary, final. "
        "Channel must be included for every message.<|end|>"
    )
    developer_turn = f"<|start|>developer<|message|># Instructions\n\n{_LENS_CHAT_SYSTEM}\n\n<|end|>"
    user_turn = f"<|start|>user<|message|>{prompt}<|end|>"
    analysis_turn = f"<|start|>assistant<|channel|>analysis<|message|>{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
    final_prefix = "<|start|>assistant<|channel|>final<|message|>"
    return "".join((system_turn, developer_turn, user_turn, analysis_turn, final_prefix))
|
||
|
||
|
||
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
|
||
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
|
||
|
||
|
||
class _GptOssRawTokenizer:
|
||
"""Raw ``tokenizers.Tokenizer`` wrapper.
|
||
|
||
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
|
||
(``tokenizer_json`` key) rather than as a committed file. Extracted
|
||
it in ``sd.py`` and passes it here via ``tokenizer_data``.
|
||
"""
|
||
|
||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||
from tokenizers import Tokenizer
|
||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||
if tokenizer_json_bytes is None:
|
||
raise ValueError(
|
||
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
|
||
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
|
||
"embeds the tokenizer."
|
||
)
|
||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||
|
||
@classmethod
|
||
def from_pretrained(cls, tokenizer_data, **kwargs):
|
||
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
|
||
|
||
def __call__(self, text):
|
||
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
|
||
|
||
def get_vocab(self):
|
||
return self.tokenizer.get_vocab()
|
||
|
||
def convert_tokens_to_ids(self, tokens):
|
||
return [self.tokenizer.token_to_id(t) for t in tokens]
|
||
|
||
def decode(self, ids, **kwargs):
|
||
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
|
||
|
||
|
||
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
    """GPT-OSS tokenizer that renders every prompt through the Lens Harmony template.

    Configured without start/end tokens, without padding
    (``pad_to_max_length=False``) and with weights disabled.
    """

    # Serialized tokenizer as received in ``tokenizer_data``; echoed back by state_dict().
    tokenizer_json_data = None

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # ``None`` replaces the former shared mutable ``{}`` default; passing a dict works as before.
        if tokenizer_data is None:
            tokenizer_data = {}
        tokenizer_json = tokenizer_data.get("tokenizer_json", None)
        self.tokenizer_json_data = tokenizer_json
        super().__init__(
            tokenizer_json,
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=2880,
            embedding_key="gpt_oss",
            tokenizer_class=_GptOssRawTokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_left=False,
            disable_weights=True,
            tokenizer_data=tokenizer_data,
        )
        self.pad_token_id = _LENS_PAD_TOKEN_ID

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Return ``[[(token_id, 1.0), ...]]`` for the templated prompt.

        Empty or whitespace-only text yields ``[[]]``. The rendered prompt is
        truncated to ``LENS_MAX_TOKENS`` tokens. ``return_word_ids`` is accepted
        for interface compatibility but has no effect.
        """
        # Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
        if not text or not text.strip():
            return [[]]
        rendered = _lens_render_chat(text)
        ids = self.tokenizer(rendered)["input_ids"][:LENS_MAX_TOKENS]
        return [[(int(t), 1.0) for t in ids]]

    def state_dict(self):
        """Return ``{"tokenizer_json": payload}`` when a payload was provided, else ``{}``."""
        if self.tokenizer_json_data is not None:
            return {"tokenizer_json": self.tokenizer_json_data}
        return {}
|
||
|
||
|
||
class LensTokenizer(sd1_clip.SD1Tokenizer):
    """Top-level Lens tokenizer exposing a single ``gpt_oss`` sub-tokenizer (``LensGptOssTokenizer``)."""

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        # ``None`` instead of a shared mutable ``{}`` default; passing a dict works as before.
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data={} if tokenizer_data is None else tokenizer_data,
            name="gpt_oss",
            tokenizer=LensGptOssTokenizer,
        )
|
||
|
||
|
||
class LensGptOssClipModel(nn.Module):
    """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).

    Runs the GPT-OSS decoder on the templated prompt and captures the outputs of
    ``selected_layers``. It then drops the first ``txt_offset`` positions and
    concatenates the captured layers along the feature axis.
    """

    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
        super().__init__()
        opts = dict(model_options or {})

        operations = opts.get("custom_operations")
        if operations is None:
            operations = comfy.ops.mixed_precision_ops(
                opts.get("quantization_metadata") or {}, dtype, full_precision_mm=True
            )
        self.operations = operations

        self.config = GptOss20BConfig(**opts.get("gpt_oss_config", {}))
        self.selected_layers = tuple(opts.get("selected_layers", LENS_SELECTED_LAYERS))
        self.txt_offset = int(opts.get("txt_offset", LENS_TXT_OFFSET))

        self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers
        self.dtype = dtype
        self.execution_device = None
        self._pad_token_id = _LENS_PAD_TOKEN_ID

    def set_clip_options(self, options):
        """Pick up ``execution_device`` from ``options``; other keys are ignored."""
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        """Forget the execution device set by ``set_clip_options``."""
        self.execution_device = None

    def _gather_tokens(self, token_weight_pairs):
        """Right-pad token ids into ``(batch, max_len)`` id and 0/1 mask tensors."""
        rows = [[int(pair[0]) for pair in batch] for batch in token_weight_pairs]
        width = max(len(row) for row in rows)
        padded = [row + [self._pad_token_id] * (width - len(row)) for row in rows]
        flags = [[1] * len(row) + [0] * (width - len(row)) for row in rows]
        device = self.execution_device
        ids = torch.tensor(padded, dtype=torch.long, device=device)
        mask = torch.tensor(flags, dtype=torch.long, device=device)
        return ids, mask

    def encode_token_weights(self, token_weight_pairs):
        """Encode a batch of token lists (weights are ignored).

        Returns ``(features, None, extra)``. ``features`` is
        ``(batch, seq, n_layers * hidden)``. ``extra`` holds the trimmed
        ``"attention_mask"`` and ``"num_layers_stacked"``.
        """
        if all(len(batch) == 0 for batch in token_weight_pairs):
            # Empty negative: emit zero-length features + zero mask
            device = self.execution_device
            batch_size = len(token_weight_pairs)
            n_stacked = len(self.selected_layers)
            width = n_stacked * self.config.hidden_size
            features = torch.zeros(batch_size, 0, width, dtype=self.dtype, device=device)
            empty_mask = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
            return features, None, {"attention_mask": empty_mask, "num_layers_stacked": n_stacked}

        input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
        captured = self.transformer(
            input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers
        )["hidden_states"]
        stacked = torch.stack(captured, dim=2)  # [B, S, L, H]

        offset = self.txt_offset
        if stacked.shape[1] <= offset:
            # Nothing survives the prefix cut.
            stacked, mask_trim = stacked[:, :0], attn_mask[:, :0]
        else:
            stacked, mask_trim = stacked[:, offset:].contiguous(), attn_mask[:, offset:]

        batch_size, seq_len, n_layers, hidden = stacked.shape
        features = stacked.reshape(batch_size, seq_len, n_layers * hidden)
        return features, None, {"attention_mask": mask_trim, "num_layers_stacked": n_layers}

    def load_sd(self, sd):
        """Load decoder weights non-strictly, assigning checkpoint tensors directly."""
        return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||
|
||
|
||
class LensTEModel(sd1_clip.SD1ClipModel):
    """Lens text-encoder wrapper with one ``gpt_oss`` clip backed by ``LensGptOssClipModel``."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        options = model_options if model_options else {}
        super().__init__(
            device=device,
            dtype=dtype,
            name="gpt_oss",
            clip_model=LensGptOssClipModel,
            model_options=options,
        )
|
||
|
||
|
||
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a ``LensTEModel`` subclass with a default dtype and quantization metadata baked in.

    ``dtype_llama`` is used only when the caller passes no ``dtype``.
    ``llama_quantization_metadata``, when not None, overrides
    ``model_options["quantization_metadata"]``.
    """
    class LensTEModel_(LensTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            options = dict(model_options or {})
            if llama_quantization_metadata is not None:
                options["quantization_metadata"] = llama_quantization_metadata
            effective_dtype = dtype if dtype is not None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=options)

    return LensTEModel_
|