Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-03 18:20:26 +08:00)
vectorized correct implementation of moe forward
This commit is contained in:
parent 4a5509a4c5
commit 61b1efdaf0
@@ -280,23 +280,43 @@ def conv_nd(dims, *args, **kwargs):
 def normalization(channels, **kwargs):
     return nn.GroupNorm(32, channels, **kwargs)
 
-def topkgating(
-    logits: torch.Tensor,
-    topk: int,
-    norm_topk_prob: bool = True,
-):
+def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
 
-    values_all, indices_all = torch.topk(gates, topk, dim=1)
-    expert_weight = values_all[:, :topk]
-    expert_index = indices_all[:, :topk]
+    num_experts = int(gates.shape[1])
 
-    if norm_topk_prob and topk > 1:
-        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
-        expert_weight = expert_weight / denom
+    _, expert_index = torch.topk(gates, topk)
+    expert_mask = F.one_hot(expert_index, num_experts)
 
-    return expert_weight, expert_index
+    expert_index_flat = expert_index.flatten()
+    tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts)
+    expert_capacity = torch.max(tokens_per_expert).item()
+
+    gates_s = torch.clamp(
+        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
+    )
+    router_probs = gates / gates_s
+
+    expert_index = torch.transpose(expert_index, 0, 1)
+    expert_index = expert_index.reshape(-1)
+    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
+
+    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
+    token_priority = token_priority.reshape((topk, -1, num_experts))
+    token_priority = torch.transpose(token_priority, 0, 1)
+
+    token_priority = torch.max(token_priority, dim=1)[0]
+
+    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
+    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
+    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
+    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
+    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
+
+    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
+
+    return combine_weights, dispatch_mask
 
 
 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
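Editor's note (not part of the commit): the rewritten topkgating stops returning per-token (weight, expert index) pairs and instead returns the Switch-Transformer-style (combine_weights, dispatch_mask) pair that the vectorized MoE forward consumes. A minimal sketch of how the two returned tensors are meant to be used, assuming the patched topkgating is in scope; the sizes and variable names below are illustrative only:

    import torch

    # illustrative sizes only: tokens, experts, hidden size, routed experts per token
    s, e, m, topk = 6, 4, 8, 2
    tokens = torch.randn(s, m)
    logits = torch.randn(s, e)

    combine_weights, dispatch_mask = topkgating(logits, topk)
    # dispatch_mask:   (s, e, capacity) bool, one capacity slot per routed token/expert pair
    # combine_weights: (s, e, capacity) float, router probability written into that same slot

    # dispatch tokens into per-expert capacity slots (same einsum as in HunyuanMoE.forward)
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), tokens)

    # with identity "experts", combining scatters the weighted rows back to token order
    recombined = torch.einsum("sec,ecm->sm", combine_weights.float(), dispatched)
    print(recombined.shape)  # torch.Size([6, 8])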
@@ -436,15 +456,13 @@ class HunyuanTopKGate(nn.Module):
         num_experts = 64
         self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
 
-        self.norm_topk_prob = True
-
     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
         hidden_states = hidden_states.reshape(-1, hidden_size)
         if self.wg.weight.dtype == torch.float32:
             hidden_states = hidden_states.float()
         logits = self.wg(hidden_states)
-        gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob,)
+        gate_output = topkgating(logits, self.moe_topk)
 
         return gate_output
 
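Editor's note (not part of the commit): the gate keeps its routing weight in float32 and upcasts the flattened hidden states before computing logits, so the routing softmax runs in full precision regardless of the model dtype. A small standalone illustration of that pattern, with sizes and names that are illustrative rather than taken from the repo:

    import torch
    import torch.nn as nn

    hidden_size, num_experts = 64, 8
    wg = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)

    hidden_states = torch.randn(2, 5, hidden_size, dtype=torch.bfloat16)
    flat = hidden_states.reshape(-1, hidden_size)
    if wg.weight.dtype == torch.float32:   # same check as in HunyuanTopKGate.forward
        flat = flat.float()                # upcast activations to match the fp32 router weight
    logits = wg(flat)                      # (bsz * seq_len, num_experts), float32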
@@ -645,46 +663,33 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)
 
         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index = self.gate(hidden_states)
+            combine_weights, dispatch_mask = self.gate(hidden_states)
+
+            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)
             device = hidden_states.device
             dtype = reshaped_input.dtype
 
+            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
+            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()
+
             combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
 
-            per_pos = [None] * self.num_experts
-            per_tokens = [None] * self.num_experts
-            per_weights = [None] * self.num_experts
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                token_ids = token_mask.nonzero(as_tuple=False)
-                if token_ids.numel() == 0:
-                    continue
-                token_positions = token_ids[:, 0]
-                topk_slot = token_ids[:, 1]
-                per_pos[e] = token_positions
-                per_tokens[e] = reshaped_input[token_positions]
-                per_weights[e] = expert_weight[token_positions, topk_slot]
-
-            used = [i for i, t in enumerate(per_tokens) if t is not None]
-            if len(used) == 0:
+            if len(used_indices) == 0:
                 pass
             else:
-                tokens_list = [per_tokens[i] for i in used]
-                weights_list = [per_weights[i] for i in used]
-                lengths = [t.shape[0] for t in tokens_list]
-                U = len(tokens_list)
-                L = max(lengths)
-                H = hidden_size
+                tokens_padded = dispatched_input[used_indices]
 
-                tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
-                weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
-                for idx, t in enumerate(tokens_list):
-                    n = t.shape[0]
-                    tokens_padded[idx, :n] = t
-                    weights_padded[idx, :n] = weights_list[idx]
+                l1_layers, l2_layers = [], []
+                for i in used_indices:
+                    expert = self.experts[i]
+                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                        expert = expert.result()
+                    expert = expert.to(device)
+                    l1_layers.append(expert.gate_and_up_proj)
+                    l2_layers.append(expert.down_proj)
 
                 l1, l2 = [], []
-                for i in used:
+                for i in used_indices:
                     expert = self.experts[i]
                     if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                         expert = expert.result()
@@ -708,19 +713,18 @@ class HunyuanMoE(nn.Module):
                     time.sleep(0.001)
 
                 x = torch.bmm(tokens_padded, W1_T)
-                x = F.silu(x)
                 x1, x2 = x.chunk(2, dim=2)
-                out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+                gated = x1 * F.silu(x2)
+                out_padded = torch.bmm(gated, W2_T)
 
-                out_padded = out_padded * weights_padded.unsqueeze(-1)
+                combine_weights_used = combine_weights[:, used_indices, :]
 
-                for idx, orig_expert_idx in enumerate(used):
-                    pos = per_pos[orig_expert_idx]
-                    n = lengths[idx]
-                    out_i = out_padded[idx, :n]
-                    combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+                combined_output = torch.einsum("suc,ucm->sm",
+                    combine_weights_used.type_as(out_padded),
+                    out_padded
+                )
 
-                del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
+                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
 
         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
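Editor's note (not part of the commit): the old path scaled each expert's rows by weights_padded and scattered them back per token with index_add_; the new path folds the weighting and the scatter into a single einsum over the capacity dimension, which is what makes the forward pass vectorized. A standalone sketch of that combine step, with illustrative shapes and names not taken from the repo:

    import torch

    s, u, c, m = 6, 3, 4, 8                       # tokens, used experts, capacity, hidden size
    combine_weights_used = torch.rand(s, u, c)    # router weights, zero outside assigned slots
    out_padded = torch.randn(u, c, m)             # per-expert outputs, padded to capacity

    # one einsum sums over (expert, capacity slot) for every token at once
    combined = torch.einsum("suc,ucm->sm", combine_weights_used, out_padded)

    # equivalent loop form, shown only to make explicit what the einsum replaces
    manual = torch.zeros(s, m)
    for e in range(u):
        for slot in range(c):
            manual += combine_weights_used[:, e, slot:slot + 1] * out_padded[e, slot]
    assert torch.allclose(combined, manual, atol=1e-5)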