vectorized implementation of moe/fixes for issues

Yousef Rafat 2025-10-31 23:53:13 +02:00
parent de43880bdb
commit a2fff60d4c
2 changed files with 59 additions and 20 deletions


@@ -4,6 +4,7 @@ import math
 import torch
 import psutil
 import torch.nn as nn
+from pathlib import Path
 from einops import rearrange
 import torch.nn.functional as F
 from collections import OrderedDict
@@ -460,13 +461,13 @@ class HunyuanMLP(nn.Module):
         self.intermediate_size = 3072
         self.act_fn = torch.nn.functional.silu
-        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
-        if x.ndim == 2:
-            x = x.unsqueeze(0)
         self.intermediate_size *= 2  # SwiGLU
         self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)

     def forward(self, x):
+        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
+        if x.ndim == 2:
+            x = x.unsqueeze(0)
         gate_and_up_proj = self.gate_and_up_proj(x)
         x1, x2 = gate_and_up_proj.chunk(2, dim=2)
         down_proj = self.down_proj(x1 * self.act_fn(x2))
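For context, the forward pass in this hunk is the standard SwiGLU pattern: one fused matmul produces both the gate and value halves, which are split and gated before the down projection. A minimal standalone sketch of the same pattern (the sizes here are illustrative placeholders, not the model's real config):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwiGLUMLP(nn.Module):
        # Minimal SwiGLU block; 64/128 are placeholder sizes.
        def __init__(self, hidden_size=64, intermediate_size=128):
            super().__init__()
            # One fused projection emits both halves (2 * intermediate_size).
            self.gate_and_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, x):
            x1, x2 = self.gate_and_up_proj(x).chunk(2, dim=-1)
            # Gate one half with SiLU of the other, then project back down.
            return self.down_proj(x1 * F.silu(x2))

    x = torch.randn(2, 7, 64)
    print(SwiGLUMLP()(x).shape)  # torch.Size([2, 7, 64])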
@@ -654,7 +655,9 @@ class LazyMoELoader(nn.Module):
         self.device = device

     def lazy_init(self, config, layer_idx, expert_idx):
-        checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors"
+        comfyui_dir = Path.home() / "ComfyUI"
+        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
+        checkpoint = checkpoint.resolve()
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 single-GPU checkpoint expected at: {checkpoint}")
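The lazy loader resolves a fixed checkpoint path and fails early if it is missing. For illustration, per-expert tensors can be read from a .safetensors file on demand, without materializing the whole checkpoint, via safetensors' safe_open. The key layout below is a hypothetical example, not the real checkpoint's naming scheme:

    from pathlib import Path
    import torch
    from safetensors import safe_open

    checkpoint = (Path.home() / "ComfyUI" / "models" / "checkpoint"
                  / "hunyuan_image_3.safetensors").resolve()

    def load_expert_tensors(layer_idx: int, expert_idx: int) -> dict[str, torch.Tensor]:
        # Hypothetical key prefix; the real checkpoint's names may differ.
        prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}."
        tensors = {}
        with safe_open(str(checkpoint), framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith(prefix):
                    # Only the matching tensors are read from disk.
                    tensors[key[len(prefix):]] = f.get_tensor(key)
        return tensors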
@@ -720,34 +723,65 @@ class HunyuanMoE(nn.Module):
             self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
         combined_output = torch.zeros_like(reshaped_input)
+        experts_list = []
         for e in range(self.num_experts):
             token_mask = (expert_index == e)
             if not token_mask.any():
                 continue
+            expert = self.moe_lru.get_from_device(e + self.layer_idx)
+            if expert is None:
+                expert = LazyMoELoader()
+                expert = expert.lazy_init(self.config, self.layer_idx, e)
+                self.moe_lru.add_gpu(expert, e + self.layer_idx)
+            experts_list.append((e, expert))
+        per_pos, per_tokens, per_weights = [], [], []
+        for e, _ in experts_list:
+            token_mask = (expert_index == e)
             token_ids = token_mask.nonzero(as_tuple=False)
             token_positions = token_ids[:, 0]
             topk_slot = token_ids[:, 1]
             tokens = reshaped_input[token_positions]
             weights = expert_weight[token_positions, topk_slot]
-            if self.experts is not None and INIT_MOE:
-                out = self.experts[e](tokens)
-            elif self.experts is None:
-                expert = self.moe_lru.get_from_device(e + self.layer_idx)
-                if expert is None:
-                    expert = LazyMoELoader()
-                    out = expert.lazy_init(self.config, self.layer_idx, e)(tokens)
-                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                else:
-                    tokens = tokens.to(next(expert.parameters()).device)
-                    out = expert(tokens.view(bsz, -1, hidden_size))
+            per_pos.append(token_positions)
+            per_tokens.append(tokens)
+            per_weights.append(weights)
-            out = out * weights.to(out.device).unsqueeze(-1)
+        lengths = [t.shape[0] for t in per_tokens]
+        E = len(per_tokens)
+        L = max(lengths)
+        tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
+        weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
+        for i, t in enumerate(per_tokens):
+            tokens_padded[i, : t.shape[0]] = t
+            weights_padded[i, : t.shape[0]] = per_weights[i]
+        l1, l2 = [], []
+        for _, expert in experts_list:
+            l1.append(expert.gate_and_up_proj)
+            l2.append(expert.down_proj)
+        W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
+        W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
+        W1_T = W1.transpose(1, 2)
+        W2_T = W2.transpose(1, 2)
+        x = torch.bmm(tokens_padded, W1_T)
+        x1, x2 = x.chunk(2, dim=-1)  # SwiGLU: split the fused gate/up halves before the down matmul
+        out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
+        out_padded = out_padded * weights_padded.unsqueeze(-1)
+        for i, token_positions in enumerate(per_pos):
+            Ni = lengths[i]
+            out_i = out_padded[i, :Ni]
+            combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
-            combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size))
         #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
         #chunks = dispatched_input.chunk(self.num_experts, dim=0)
         #expert_outputs = []
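The new code replaces the per-expert Python loop with one padded batched matmul: tokens are grouped per expert, right-padded to a common length, pushed through the stacked expert weights with torch.bmm, and scattered back to their original positions with index_add_. A self-contained sketch of that dispatch pattern, using top-1 routing and illustrative sizes:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    num_experts, tokens, hidden, inter = 4, 10, 8, 16
    x = torch.randn(tokens, hidden)
    assignment = torch.randint(num_experts, (tokens,))  # top-1 routing for brevity
    W1 = torch.randn(num_experts, 2 * inter, hidden)    # fused gate+up weights
    W2 = torch.randn(num_experts, hidden, inter)        # down weights

    groups = [(assignment == e).nonzero(as_tuple=True)[0] for e in range(num_experts)]
    L = max(g.numel() for g in groups)

    # Right-pad each expert's token group to a common length L.
    padded = torch.zeros(num_experts, L, hidden)
    for e, g in enumerate(groups):
        padded[e, :g.numel()] = x[g]

    h = torch.bmm(padded, W1.transpose(1, 2))                    # (E, L, 2*inter)
    h1, h2 = h.chunk(2, dim=-1)                                  # SwiGLU split
    out_padded = torch.bmm(h1 * F.silu(h2), W2.transpose(1, 2))  # (E, L, hidden)

    # Scatter each expert's rows back to their original token positions.
    out = torch.zeros_like(x)
    for e, g in enumerate(groups):
        out.index_add_(0, g, out_padded[e, :g.numel()])
    print(out.shape)  # torch.Size([10, 8])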
@@ -1014,12 +1048,12 @@ class HunyuanImage3ForCausalMM(nn.Module):
             dtype=x.dtype,
         )
-        image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool)
+        image_mask = torch.ones(x.size(1))
+        image_mask[:5] = torch.zeros(5); image_mask[-4:] = torch.zeros(4)
         gen_timestep_scatter_index = 4
         cond, uncond = condition[:4], condition[4:]
         joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1], cond[2]
         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
         token_height = height // (16 * 16)
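The mask fix above flags which sequence positions hold image tokens while zeroing the special prefix/suffix slots. The same pattern on a 1-D boolean mask (the 5 leading / 4 trailing counts are taken from the diff; treat them as illustrative):

    import torch

    seq_len = 16
    image_mask = torch.ones(seq_len, dtype=torch.bool)
    image_mask[:5] = False    # leading special/text tokens
    image_mask[-4:] = False   # trailing special tokens
    print(image_mask.nonzero(as_tuple=True)[0])  # positions 5..11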


@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
+import comfy.ldm.hunyuan_image_3.model
 import comfy.model_management
 import comfy.patcher_extension
@@ -1196,6 +1197,10 @@ class Hunyuan3Dv2(BaseModel):
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

+class HunyuanImage3(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device, unet_model=comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM)
+
 class HiDream(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):