Vectorized implementation of MoE / fixes for issues

This commit is contained in:
Yousef Rafat 2025-10-31 23:53:13 +02:00
parent de43880bdb
commit a2fff60d4c
2 changed files with 59 additions and 20 deletions

View File

@ -4,6 +4,7 @@ import math
import torch import torch
import psutil import psutil
import torch.nn as nn import torch.nn as nn
from pathlib import Path
from einops import rearrange from einops import rearrange
import torch.nn.functional as F import torch.nn.functional as F
from collections import OrderedDict from collections import OrderedDict
@ -460,13 +461,13 @@ class HunyuanMLP(nn.Module):
self.intermediate_size = 3072 self.intermediate_size = 3072
self.act_fn = torch.nn.functional.silu self.act_fn = torch.nn.functional.silu
self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
if x.ndim == 2:
x = x.unsqueeze(0)
self.intermediate_size *= 2 # SwiGLU self.intermediate_size *= 2 # SwiGLU
self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)
def forward(self, x): def forward(self, x):
self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
if x.ndim == 2:
x = x.unsqueeze(0)
gate_and_up_proj = self.gate_and_up_proj(x) gate_and_up_proj = self.gate_and_up_proj(x)
x1, x2 = gate_and_up_proj.chunk(2, dim=2) x1, x2 = gate_and_up_proj.chunk(2, dim=2)
down_proj = self.down_proj(x1 * self.act_fn(x2)) down_proj = self.down_proj(x1 * self.act_fn(x2))
@ -654,7 +655,9 @@ class LazyMoELoader(nn.Module):
self.device = device self.device = device
def lazy_init(self, config, layer_idx, expert_idx): def lazy_init(self, config, layer_idx, expert_idx):
checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors" comfyui_dir = Path.home() / "ComfyUI"
checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
checkpoint = checkpoint.resolve()
if not os.path.exists(checkpoint): if not os.path.exists(checkpoint):
raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}") raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")
@ -720,34 +723,65 @@ class HunyuanMoE(nn.Module):
self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx) self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)
combined_output = torch.zeros_like(reshaped_input) combined_output = torch.zeros_like(reshaped_input)
experts_list = []
for e in range(self.num_experts): for e in range(self.num_experts):
token_mask = (expert_index == e) token_mask = (expert_index == e)
if not token_mask.any(): if not token_mask.any():
continue continue
expert = self.moe_lru.get_from_device(e + self.layer_idx)
if expert is None:
expert = LazyMoELoader()
expert = expert.lazy_init(self.config, self.layer_idx, e)
self.moe_lru.add_gpu(expert, e + self.layer_idx)
experts_list.append((e, expert))
per_pos, per_tokens, per_weights = [], [], []
for e, _ in experts_list:
token_mask = (expert_index == e)
token_ids = token_mask.nonzero(as_tuple=False) token_ids = token_mask.nonzero(as_tuple=False)
token_positions = token_ids[:, 0] token_positions = token_ids[:, 0]
topk_slot = token_ids[:, 1] topk_slot = token_ids[:, 1]
tokens = reshaped_input[token_positions] tokens = reshaped_input[token_positions]
weights = expert_weight[token_positions, topk_slot] weights = expert_weight[token_positions, topk_slot]
if self.experts is not None and INIT_MOE: per_pos.append(token_positions)
out = self.experts[e](tokens) per_tokens.append(tokens)
elif self.experts is None: per_weights.append(weights)
expert = self.moe_lru.get_from_device(e + self.layer_idx)
if expert is None:
expert = LazyMoELoader()
out = expert.lazy_init(self.config, self.layer_idx, e)(tokens)
self.moe_lru.add_gpu(expert, e + self.layer_idx)
else:
tokens = tokens.to(next(expert.parameters()).device)
out = expert(tokens.view(bsz, -1, hidden_size))
out = out * weights.to(out.device).unsqueeze(-1) lengths = [t.shape[0] for t in per_tokens]
E = len(per_tokens)
L = max(lengths)
tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
for i, t in enumerate(per_tokens):
tokens_padded[i, : t.shape[0]] = t
weights_padded[i, : t.shape[0]] = per_weights[i]
l1, l2 = [], []
for _, expert in experts_list:
l1.append(expert.gate_and_up_proj)
l2.append(expert.down_proj)
W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
W1_T = W1.transpose(1, 2)
W2_T = W2.transpose(1, 2)
x = torch.bmm(tokens_padded, W1_T)
x = F.silu(x)
out_padded = torch.bmm(x, W2_T)
out_padded = out_padded * weights_padded.unsqueeze(-1)
for i, token_positions in enumerate(per_pos):
Ni = lengths[i]
out_i = out_padded[i, :Ni]
combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)
combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size))
#dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
#chunks = dispatched_input.chunk(self.num_experts, dim=0) #chunks = dispatched_input.chunk(self.num_experts, dim=0)
#expert_outputs = [] #expert_outputs = []
@ -1014,12 +1048,12 @@ class HunyuanImage3ForCausalMM(nn.Module):
dtype=x.dtype, dtype=x.dtype,
) )
image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool) image_mask = torch.ones(x.size(1))
image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
gen_timestep_scatter_index = 4 gen_timestep_scatter_index = 4
cond, uncond = condition[:4], condition[4:] cond, uncond = condition[:4], condition[4:]
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1] joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1) position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
height, width = x.shape[2] * 16, x.shape[3] * 16 height, width = x.shape[2] * 16, x.shape[3] * 16
token_height = height // (16 * 16) token_height = height // (16 * 16)

View File

@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.hunyuan_image_3.model
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -1197,6 +1198,10 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out return out
class HunyuanImage3(BaseModel):
    """BaseModel wrapper whose backbone is ``HunyuanImage3ForCausalMM``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        """Forward the configuration to ``BaseModel`` with the Hunyuan Image 3 backbone.

        Args:
            model_config: Passed through unchanged to ``BaseModel.__init__``.
            model_type: Passed through unchanged; defaults to ``ModelType.FLOW``.
            device: Passed through unchanged as the ``device`` keyword.
        """
        # The default uses the upper-case FLOW spelling, matching the sibling
        # HiDream wrapper in this file. Enum member access is case-sensitive and
        # defaults are evaluated at definition time, so a misspelled member
        # would presumably break the module import -- verify against ModelType.
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM,
        )
class HiDream(BaseModel): class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)