diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 6b292ed7f..8a01cd155 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -280,23 +280,43 @@ def conv_nd(dims, *args, **kwargs):
 def normalization(channels, **kwargs):
     return nn.GroupNorm(32, channels, **kwargs)
 
-def topkgating(
-    logits: torch.Tensor,
-    topk: int,
-    norm_topk_prob: bool = True,
-):
+def topkgating(logits: torch.Tensor, topk: int):
     logits = logits.float()
     gates = F.softmax(logits, dim=1)
-    values_all, indices_all = torch.topk(gates, topk, dim=1)
-    expert_weight = values_all[:, :topk]
-    expert_index = indices_all[:, :topk]
+    num_experts = int(gates.shape[1])
 
-    if norm_topk_prob and topk > 1:
-        denom = expert_weight.sum(dim=1, keepdim=True).clamp_min(torch.finfo(gates.dtype).eps)
-        expert_weight = expert_weight / denom
+    _, expert_index = torch.topk(gates, topk)
+    expert_mask = F.one_hot(expert_index, num_experts)
 
-    return expert_weight, expert_index
+    expert_index_flat = expert_index.flatten()
+    tokens_per_expert = torch.bincount(expert_index_flat, minlength=num_experts)
+    expert_capacity = torch.max(tokens_per_expert).item()
+
+    gates_s = torch.clamp(
+        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
+    )
+    router_probs = gates / gates_s
+
+    expert_index = torch.transpose(expert_index, 0, 1)
+    expert_index = expert_index.reshape(-1)
+    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
+
+    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
+    token_priority = token_priority.reshape((topk, -1, num_experts))
+    token_priority = torch.transpose(token_priority, 0, 1)
+
+    token_priority = torch.max(token_priority, dim=1)[0]
+
+    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
+    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
+    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
+    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
+    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
+
+    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
+
+    return combine_weights, dispatch_mask
 
 class HunyuanRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -436,15 +456,13 @@ class HunyuanTopKGate(nn.Module):
         num_experts = 64
         self.wg = nn.Linear(config["hidden_size"], num_experts, bias=False, dtype=torch.float32)
 
-        self.norm_topk_prob = True
-
     def forward(self, hidden_states):
         bsz, seq_len, hidden_size = hidden_states.shape
         hidden_states = hidden_states.reshape(-1, hidden_size)
         if self.wg.weight.dtype == torch.float32:
             hidden_states = hidden_states.float()
         logits = self.wg(hidden_states)
-        gate_output = topkgating(logits, self.moe_topk, norm_topk_prob=self.norm_topk_prob,)
+        gate_output = topkgating(logits, self.moe_topk)
 
         return gate_output
 
@@ -645,46 +663,33 @@ class HunyuanMoE(nn.Module):
         reshaped_input = hidden_states.reshape(-1, hidden_size)
 
         with torch.cuda.nvtx.range("MoE"):
-            expert_weight, expert_index = self.gate(hidden_states)
+            combine_weights, dispatch_mask = self.gate(hidden_states)
+
+            dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)
 
             device = hidden_states.device
             dtype = reshaped_input.dtype
+
+            used_mask = (dispatch_mask.sum(dim=(0, 2)) > 0)
+            used_indices = used_mask.nonzero(as_tuple=False).squeeze(1).tolist()
 
             combined_output = torch.zeros_like(reshaped_input, device=device, dtype=dtype)
 
-            per_pos = [None] * self.num_experts
-            per_tokens = [None] * self.num_experts
-            per_weights = [None] * self.num_experts
-            for e in range(self.num_experts):
-                token_mask = (expert_index == e)
-                token_ids = token_mask.nonzero(as_tuple=False)
-                if token_ids.numel() == 0:
-                    continue
-                token_positions = token_ids[:, 0]
-                topk_slot = token_ids[:, 1]
-                per_pos[e] = token_positions
-                per_tokens[e] = reshaped_input[token_positions]
-                per_weights[e] = expert_weight[token_positions, topk_slot]
-
-            used = [i for i, t in enumerate(per_tokens) if t is not None]
-            if len(used) == 0:
+            if len(used_indices) == 0:
                 pass
             else:
-                tokens_list = [per_tokens[i] for i in used]
-                weights_list = [per_weights[i] for i in used]
-                lengths = [t.shape[0] for t in tokens_list]
-                U = len(tokens_list)
-                L = max(lengths)
-                H = hidden_size
+                tokens_padded = dispatched_input[used_indices]
 
-                tokens_padded = torch.zeros((U, L, H), device=device, dtype=dtype)
-                weights_padded = torch.zeros((U, L), device=device, dtype=weights_list[0].dtype)
-                for idx, t in enumerate(tokens_list):
-                    n = t.shape[0]
-                    tokens_padded[idx, :n] = t
-                    weights_padded[idx, :n] = weights_list[idx]
-
                 l1, l2 = [], []
-                for i in used:
+                for i in used_indices:
                     expert = self.experts[i]
                     if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
                         expert = expert.result()
@@ -708,19 +713,18 @@ class HunyuanMoE(nn.Module):
                         time.sleep(0.001)
 
                 x = torch.bmm(tokens_padded, W1_T)
-                x = F.silu(x)
                 x1, x2 = x.chunk(2, dim=2)
-                out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+                gated = x1 * F.silu(x2)
+                out_padded = torch.bmm(gated, W2_T)
 
-                out_padded = out_padded * weights_padded.unsqueeze(-1)
+                combine_weights_used = combine_weights[:, used_indices, :]
 
-                for idx, orig_expert_idx in enumerate(used):
-                    pos = per_pos[orig_expert_idx]
-                    n = lengths[idx]
-                    out_i = out_padded[idx, :n]
-                    combined_output.index_add_(0, pos.to(device), out_i.to(combined_output.dtype))
+                combined_output = torch.einsum("suc,ucm->sm",
+                    combine_weights_used.type_as(out_padded),
+                    out_padded
+                )
 
-                del tokens_padded, weights_padded, W1, W2, W1_T, W2_T, x, x1, x2, out_padded
+                del tokens_padded, W1, W2, W1_T, W2_T, x, x1, x2, gated, out_padded
 
         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
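
Note on the new gating contract: `topkgating` no longer returns per-token `(expert_weight, expert_index)` pairs; it returns dense `combine_weights` and `dispatch_mask` tensors of shape `(tokens, num_experts, expert_capacity)`, with the capacity sized to the busiest expert so no token is dropped. A minimal sanity-check sketch of those invariants, assuming the patched module is importable (or that `topkgating` is pasted into a scratch file):

import torch

from comfy.ldm.hunyuan_image_3.model import topkgating

tokens, num_experts, topk = 8, 4, 2
logits = torch.randn(tokens, num_experts)
combine_weights, dispatch_mask = topkgating(logits, topk)

# Both outputs are (tokens, num_experts, expert_capacity).
assert combine_weights.shape == dispatch_mask.shape
assert combine_weights.shape[:2] == (tokens, num_experts)

# Capacity equals the busiest expert's load, so every token keeps all
# `topk` of its (expert, slot) assignments.
assert torch.all(dispatch_mask.sum(dim=(1, 2)) == topk)

# Each capacity slot of each expert is claimed by at most one token.
assert torch.all(dispatch_mask.sum(dim=0) <= 1)

# Combine weights are the gate values renormalized over the selected
# experts, so they sum to 1 per token.
assert torch.allclose(combine_weights.sum(dim=(1, 2)), torch.ones(tokens))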
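
The forward pass correspondingly swaps the per-expert gather/pad/`index_add_` scatter for two dense einsums. A round-trip sketch with identity experts (same hypothetical scratch setup as above) shows why the combine step needs no extra normalization: the diff's `"suc,ucm->sm"` contraction over `used_indices` is the full `"sec,ecm->sm"` below restricted to non-empty experts, and unused experts contribute only zeros.

import torch

from comfy.ldm.hunyuan_image_3.model import topkgating

tokens, num_experts, topk, hidden = 8, 4, 2, 16
reshaped_input = torch.randn(tokens, hidden)
combine_weights, dispatch_mask = topkgating(torch.randn(tokens, num_experts), topk)

# Dispatch: scatter tokens into dense (expert, capacity, hidden) buckets.
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(reshaped_input), reshaped_input)

# The expert MLPs would transform `dispatched` bucket-by-bucket here; the
# identity isolates the routing algebra.
expert_out = dispatched

# Combine: gather each token's slots back, weighted by its router probabilities.
combined = torch.einsum("sec,ecm->sm", combine_weights.type_as(expert_out), expert_out)

# With identity experts and per-token combine weights summing to 1, the
# round trip reproduces the input up to float error.
assert torch.allclose(combined, reshaped_input, atol=1e-5)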