mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-29 02:17:52 +08:00
Fix fp16 TE usage, some cleanup
This commit is contained in:
parent
a49d578f9e
commit
0b0f1b1cf6
@ -48,16 +48,8 @@ class GptOss20BConfig:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _yarn_inv_freq(
|
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
|
||||||
head_dim: int,
|
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
|
||||||
base: float,
|
|
||||||
factor: float,
|
|
||||||
beta_fast: float,
|
|
||||||
beta_slow: float,
|
|
||||||
original_max_position_embeddings: int,
|
|
||||||
truncate: bool,
|
|
||||||
device=None,
|
|
||||||
) -> tuple[torch.Tensor, float]:
|
|
||||||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||||||
dim = head_dim
|
dim = head_dim
|
||||||
|
|
||||||
@ -109,15 +101,8 @@ def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_
|
|||||||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||||||
|
|
||||||
|
|
||||||
def _attention_with_sinks(
|
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
|
||||||
q: torch.Tensor,
|
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
sinks: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_groups: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Attention with per-head sinks.
|
"""Attention with per-head sinks.
|
||||||
|
|
||||||
Sinks add a learned term to each row's softmax denominator but contribute
|
Sinks add a learned term to each row's softmax denominator but contribute
|
||||||
@ -186,14 +171,15 @@ class GptOssTopKRouter(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
# Raw Parameters (not Linear) to match HF state-dict keys.
|
|
||||||
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
||||||
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
logits = F.linear(hidden_states, self.weight, self.bias)
|
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
|
||||||
|
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
|
||||||
|
logits = F.linear(hidden_states, weight, bias)
|
||||||
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
||||||
# Softmax over top-k slice only (matches transformers), not all experts.
|
# Softmax over top-k slice only
|
||||||
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
||||||
return scores, top_idx
|
return scores, top_idx
|
||||||
|
|
||||||
@ -220,7 +206,7 @@ class GptOssExperts(nn.Module):
|
|||||||
gate = gate.clamp(max=self.limit)
|
gate = gate.clamp(max=self.limit)
|
||||||
up = up.clamp(min=-self.limit, max=self.limit)
|
up = up.clamp(min=-self.limit, max=self.limit)
|
||||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||||
return (up + 1) * glu
|
return torch.addcmul(glu, up, glu)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||||
N = hidden_states.shape[0]
|
N = hidden_states.shape[0]
|
||||||
@ -543,9 +529,6 @@ class LensGptOssClipModel(nn.Module):
|
|||||||
self.execution_device = None
|
self.execution_device = None
|
||||||
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
||||||
|
|
||||||
for p in self.parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
self.execution_device = options.get("execution_device", self.execution_device)
|
self.execution_device = options.get("execution_device", self.execution_device)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user