Fix fp16 TE usage, some cleanup

This commit is contained in:
kijai 2026-05-24 12:40:17 +03:00
parent a49d578f9e
commit 0b0f1b1cf6

View File

@ -48,16 +48,8 @@ class GptOss20BConfig:
] ]
def _yarn_inv_freq( def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
head_dim: int, original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
base: float,
factor: float,
beta_fast: float,
beta_slow: float,
original_max_position_embeddings: int,
truncate: bool,
device=None,
) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers).""" """YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim dim = head_dim
@ -109,15 +101,8 @@ def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_
return cos, sin[..., :sin_split], -sin[..., sin_split:] return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks( def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
q: torch.Tensor, attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
k: torch.Tensor,
v: torch.Tensor,
sinks: torch.Tensor,
attention_mask: Optional[torch.Tensor],
num_heads: int,
num_kv_groups: int,
) -> torch.Tensor:
"""Attention with per-head sinks. """Attention with per-head sinks.
Sinks add a learned term to each row's softmax denominator but contribute Sinks add a learned term to each row's softmax denominator but contribute
@ -186,14 +171,15 @@ class GptOssTopKRouter(nn.Module):
super().__init__() super().__init__()
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts self.num_experts = config.num_local_experts
# Raw Parameters (not Linear) to match HF state-dict keys.
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype)) self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype)) self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
logits = F.linear(hidden_states, self.weight, self.bias) weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
logits = F.linear(hidden_states, weight, bias)
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1) top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
# Softmax over top-k slice only (matches transformers), not all experts. # Softmax over top-k slice only
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype) scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
return scores, top_idx return scores, top_idx
@ -220,7 +206,7 @@ class GptOssExperts(nn.Module):
gate = gate.clamp(max=self.limit) gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit) up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha) glu = gate * torch.sigmoid(gate * self.alpha)
return (up + 1) * glu return torch.addcmul(glu, up, glu)
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
N = hidden_states.shape[0] N = hidden_states.shape[0]
@ -543,9 +529,6 @@ class LensGptOssClipModel(nn.Module):
self.execution_device = None self.execution_device = None
self._pad_token_id = _LENS_PAD_TOKEN_ID self._pad_token_id = _LENS_PAD_TOKEN_ID
for p in self.parameters():
p.requires_grad = False
def set_clip_options(self, options): def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device) self.execution_device = options.get("execution_device", self.execution_device)