Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-06 19:42:34 +08:00)

commit 61b1efdaf0 (parent 4a5509a4c5)

    vectorized correct implementation of MoE forward
@@ -280,23 +280,43 @@ def conv_nd(dims, *args, **kwargs):
 def normalization(channels, **kwargs):
     return nn.GroupNorm(32, channels, **kwargs)
 
-def topkgating(
-    logits: torch.Tensor,
-    topk: int,
-    norm_topk_prob: bool = True,
-):
+def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
+    num_experts = int(gates.shape[1])
 
-    values_all, indices_all = torch.topk(gates, topk, dim=1)
-    expert_weight = values_all[:, :topk]
-    expert_index = indices_all[:, :topk]
+    _, expert_index = torch.topk(gates, topk)
+    expert_mask = F.one_hot(expert_index, num_experts)
 
-    if norm_topk_prob and topk > 1:
-        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
-        expert_weight = expert_weight / denom
+    expert_index_flat = expert_index.flatten()
+    tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts)
+    expert_capacity = torch.max(tokens_per_expert).item()
 
-    return expert_weight, expert_index
+    gates_s = torch.clamp(
+        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
+    )
+    router_probs = gates / gates_s
+
+    expert_index = torch.transpose(expert_index, 0, 1)
+    expert_index = expert_index.reshape(-1)
+    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
+
+    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
+    token_priority = token_priority.reshape((topk, -1, num_experts))
+    token_priority = torch.transpose(token_priority, 0, 1)
+
+    token_priority = torch.max(token_priority, dim=1)[0]
+
+    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
+    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
+    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
+    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
+    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
+
+    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
+
+    return combine_weights, dispatch_mask
 
 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
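The rewritten topkgating switches from per-token (weight, index) pairs to the capacity-based dispatch/combine formulation used by Switch-style MoE routers: dispatch_mask is a boolean [tokens, experts, capacity] tensor assigning each routed token a slot in its expert's buffer (capacity here is simply the busiest expert's load), and combine_weights carries the normalized router probability in the same layout. A minimal round-trip sketch with toy sizes, assuming the topkgating defined above is in scope (identity experts, so only shapes and routing are exercised):

import torch

tokens, hidden, topk = 6, 8, 2
logits = torch.randn(tokens, 4)            # 4 toy experts
combine_weights, dispatch_mask = topkgating(logits, topk)

x = torch.randn(tokens, hidden)
# Dispatch: place each token in its expert's capacity slot -> [experts, capacity, hidden]
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), x)
# (the per-expert FFNs would transform `dispatched` here; identity for this sketch)
# Combine: probability-weighted scatter back to token order -> [tokens, hidden]
combined = torch.einsum("sec,ecm->sm", combine_weights, dispatched)
assert combined.shape == x.shape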
@@ -436,15 +456,13 @@ class HunyuanTopKGate(nn.Module):
         num_experts = 64
         self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
 
-        self.norm_topk_prob = True
-
     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
         hidden_states = hidden_states.reshape(-1, hidden_size)
         if self.wg.weight.dtype == torch.float32:
             hidden_states = hidden_states.float()
         logits = self.wg(hidden_states)
-        gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob,)
+        gate_output = topkgating(logits, self.moe_topk)
 
         return gate_output
 
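Apart from dropping the norm_topk_prob plumbing (the equivalent normalization now happens inside topkgating via router_probs), the gate is unchanged; note that wg is kept in float32 so the routing softmax and top-k run at full precision even with bf16/fp16 activations. A small sketch of that dtype path, again assuming topkgating from above is in scope:

import torch

x = torch.randn(4, 16, dtype=torch.bfloat16)           # low-precision activations
wg = torch.nn.Linear(16, 8, bias=False, dtype=torch.float32)
logits = wg(x.float())                                  # cast up, as forward() does
combine_weights, dispatch_mask = topkgating(logits, 2)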
@@ -645,46 +663,33 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)
 
         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index = self.gate(hidden_states)
+            combine_weights, dispatch_mask = self.gate(hidden_states)
 
+            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)
+
             device = hidden_states.device
             dtype = reshaped_input.dtype
 
+            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
+            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()
+
             combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
 
-            per_pos = [None] * self.num_experts
-            per_tokens = [None] * self.num_experts
-            per_weights = [None] * self.num_experts
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                token_ids = token_mask.nonzero(as_tuple=False)
-                if token_ids.numel() == 0:
-                    continue
-                token_positions = token_ids[:, 0]
-                topk_slot = token_ids[:, 1]
-                per_pos[e] = token_positions
-                per_tokens[e] = reshaped_input[token_positions]
-                per_weights[e] = expert_weight[token_positions, topk_slot]
-
-            used = [i for i, t in enumerate(per_tokens) if t is not None]
-            if len(used) == 0:
+            if len(used_indices) == 0:
                 pass
             else:
-                tokens_list = [per_tokens[i] for i in used]
-                weights_list = [per_weights[i] for i in used]
-                lengths = [t.shape[0] for t in tokens_list]
-                U = len(tokens_list)
-                L = max(lengths)
-                H = hidden_size
-
-                tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
-                weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
-                for idx, t in enumerate(tokens_list):
-                    n = t.shape[0]
-                    tokens_padded[idx, :n] = t
-                    weights_padded[idx, :n] = weights_list[idx]
+                tokens_padded = dispatched_input[used_indices]
+
+                l1_layers, l2_layers = [], []
+                for i in used_indices:
+                    expert = self.experts[i]
+                    if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                        expert = expert.result()
+                    expert = expert.to(device)
+                    l1_layers.append(expert.gate_and_up_proj)
+                    l2_layers.append(expert.down_proj)
 
                 l1, l2 = [], []
-                for i in used:
+                for i in used_indices:
                     expert = self.experts[i]
                     if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                         expert = expert.result()
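Where the old forward looped over all num_experts to gather, pad, and weight each expert's tokens by hand, the new one reads the padded per-expert buffers straight out of a single einsum and merely indexes the experts that received any tokens. A rough shape-level sketch of that dispatch (toy sizes; not the commit's code):

import torch

experts, capacity, tokens, hidden = 4, 3, 6, 8
dispatch_mask = torch.zeros(tokens, experts, capacity, dtype=torch.bool)
dispatch_mask[0, 1, 0] = True                           # token 0 -> expert 1, slot 0
x = torch.randn(tokens, hidden)

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), x)
assert torch.allclose(dispatched[1, 0], x[0])           # the slot holds the routed token
used_indices = (dispatch_mask.sum(dim=(0, 2)) > 0).nonzero(as_tuple=False).squeeze(1)
tokens_padded = dispatched[used_indices]                # [used, capacity, hidden]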
@@ -708,19 +713,18 @@ class HunyuanMoE(nn.Module):
                     time.sleep(0.001)
 
                 x = torch.bmm(tokens_padded, W1_T)
-                x = F.silu(x)
                 x1, x2 = x.chunk(2, dim=2)
-                out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+                gated = x1 * F.silu(x2)
+                out_padded = torch.bmm(gated, W2_T)
 
-                out_padded = out_padded * weights_padded.unsqueeze(-1)
+                combine_weights_used = combine_weights[:, used_indices, :]
 
-                for idx, orig_expert_idx in enumerate(used):
-                    pos = per_pos[orig_expert_idx]
-                    n = lengths[idx]
-                    out_i = out_padded[idx, :n]
-                    combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+                combined_output = torch.einsum("suc,ucm->sm",
+                    combine_weights_used.type_as(out_padded),
+                    out_padded
+                )
 
-                del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
+                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
 
                 combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
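The "correct" half of the commit title lives here: the old path applied F.silu to the whole fused gate/up projection and then again to the gate half, activating twice; the fix chunks first and applies SiLU once, giving the standard SwiGLU form, while the final einsum replaces both the manual per-token weight multiply and the index_add_ scatter. A tiny sketch of the fixed activation (which half plays the gate role depends on the fused weight layout, assumed here):

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)      # [experts, capacity, 2 * ffn]: fused gate/up output
x1, x2 = x.chunk(2, dim=2)
gated = x1 * F.silu(x2)        # activate the gate half only (old code also ran silu on x)
assert gated.shape == (2, 5, 8)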