diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py index 4e15bcedd..2453d8d74 100644 --- a/comfy/text_encoders/gpt_oss.py +++ b/comfy/text_encoders/gpt_oss.py @@ -48,16 +48,8 @@ class GptOss20BConfig: ] -def _yarn_inv_freq( - head_dim: int, - base: float, - factor: float, - beta_fast: float, - beta_slow: float, - original_max_position_embeddings: int, - truncate: bool, - device=None, -) -> tuple[torch.Tensor, float]: +def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float, + original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]: """YARN inv_freq + attention scaling (matches transformers).""" dim = head_dim @@ -109,15 +101,8 @@ def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ return cos, sin[..., :sin_split], -sin[..., sin_split:] -def _attention_with_sinks( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - sinks: torch.Tensor, - attention_mask: Optional[torch.Tensor], - num_heads: int, - num_kv_groups: int, -) -> torch.Tensor: +def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor, + attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor: """Attention with per-head sinks. Sinks add a learned term to each row's softmax denominator but contribute @@ -186,14 +171,15 @@ class GptOssTopKRouter(nn.Module): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - # Raw Parameters (not Linear) to match HF state-dict keys. self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype)) self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype)) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - logits = F.linear(hidden_states, self.weight, self.bias) + weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False) + bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False) + logits = F.linear(hidden_states, weight, bias) top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1) - # Softmax over top-k slice only (matches transformers), not all experts. + # Softmax over top-k slice only scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype) return scores, top_idx @@ -220,7 +206,7 @@ class GptOssExperts(nn.Module): gate = gate.clamp(max=self.limit) up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) - return (up + 1) * glu + return torch.addcmul(glu, up, glu) def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: N = hidden_states.shape[0] @@ -543,9 +529,6 @@ class LensGptOssClipModel(nn.Module): self.execution_device = None self._pad_token_id = _LENS_PAD_TOKEN_ID - for p in self.parameters(): - p.requires_grad = False - def set_clip_options(self, options): self.execution_device = options.get("execution_device", self.execution_device)