Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-12 15:20:51 +08:00

Commit a2fff60d4c (parent de43880bdb): vectorized implementation of MoE / fixes for issues
@@ -4,6 +4,7 @@ import math
import torch
import psutil
import torch.nn as nn
from pathlib import Path
from einops import rearrange
import torch.nn.functional as F
from collections import OrderedDict
@@ -460,13 +461,13 @@ class HunyuanMLP(nn.Module):
        self.intermediate_size = 3072

        self.act_fn = torch.nn.functional.silu
        self.intermediate_size *= 2  # SwiGLU: fused gate + up projection
        self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)

    def forward(self, x):
        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
        if x.ndim == 2:
            x = x.unsqueeze(0)
        gate_and_up_proj = self.gate_and_up_proj(x)
        x1, x2 = gate_and_up_proj.chunk(2, dim=2)
        down_proj = self.down_proj(x1 * self.act_fn(x2))
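For reference, a small self-contained sketch of the fused SwiGLU pattern that HunyuanMLP uses above; the hidden and intermediate sizes here are illustrative, not the model's actual configuration.

# Minimal SwiGLU MLP sketch (illustrative sizes)
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, intermediate = 8, 16
gate_and_up = nn.Linear(hidden, 2 * intermediate, bias=False)  # fused gate + up projection
down = nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(1, 4, hidden)                # (batch, tokens, hidden)
x1, x2 = gate_and_up(x).chunk(2, dim=2)      # split the fused projection
y = down(x1 * F.silu(x2))                    # gate the up projection with SiLU
assert y.shape == x.shape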
@@ -654,7 +655,9 @@ class LazyMoELoader(nn.Module):
        self.device = device

    def lazy_init(self, config, layer_idx, expert_idx):
        checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors"
        comfyui_dir = Path.home() / "ComfyUI"
        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
        checkpoint = checkpoint.resolve()
        if not os.path.exists(checkpoint):
            raise ValueError(f"Hunyuan Image 3 checkpoint for single-GPU loading should be at: {checkpoint}")
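For context on lazy_init, here is a hypothetical sketch of loading a single expert's tensors on demand from a safetensors checkpoint. The helper name and the tensor key layout (model.layers.{layer}.mlp.experts.{expert}...) are assumptions for illustration, not the commit's actual code.

# Hypothetical per-expert lazy load from a safetensors file
from pathlib import Path
from safetensors import safe_open

def load_expert_weights(checkpoint: Path, layer_idx: int, expert_idx: int, device="cpu"):
    prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}"  # assumed key prefix
    weights = {}
    with safe_open(str(checkpoint), framework="pt", device=device) as f:
        for name in ("gate_and_up_proj.weight", "down_proj.weight"):
            key = f"{prefix}.{name}"
            if key in f.keys():
                weights[name] = f.get_tensor(key)  # reads only this tensor from disk
    return weights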
@@ -720,34 +723,65 @@ class HunyuanMoE(nn.Module):
            self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)

        combined_output = torch.zeros_like(reshaped_input)

        # Collect the experts that actually received tokens this step.
        experts_list = []
        for e in range(self.num_experts):
            token_mask = (expert_index == e)
            if not token_mask.any():
                continue
            expert = self.moe_lru.get_from_device(e + self.layer_idx)
            if expert is None:
                expert = LazyMoELoader()
                expert = expert.lazy_init(self.config, self.layer_idx, e)
                self.moe_lru.add_gpu(expert, e + self.layer_idx)
            experts_list.append((e, expert))

        # Gather, per expert, the token positions, activations and routing weights.
        per_pos, per_tokens, per_weights = [], [], []
        for e, _ in experts_list:
            token_mask = (expert_index == e)

            token_ids = token_mask.nonzero(as_tuple=False)
            token_positions = token_ids[:, 0]
            topk_slot = token_ids[:, 1]

            tokens = reshaped_input[token_positions]
            weights = expert_weight[token_positions, topk_slot]

            if self.experts is not None and INIT_MOE:
                out = self.experts[e](tokens)
            elif self.experts is None:
                expert = self.moe_lru.get_from_device(e + self.layer_idx)
                if expert is None:
                    expert = LazyMoELoader()
                    out = expert.lazy_init(self.config, self.layer_idx, e)(tokens)
                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
            else:
                tokens = tokens.to(next(expert.parameters()).device)
                out = expert(tokens.view(bsz, -1, hidden_size))

            per_pos.append(token_positions)
            per_tokens.append(tokens)
            per_weights.append(weights)

            out = out * weights.to(out.device).unsqueeze(-1)

        # Pad every expert's token batch to a common length so the expert MLPs
        # can run as batched matmuls instead of a Python loop.
        lengths = [t.shape[0] for t in per_tokens]
        E = len(per_tokens)
        L = max(lengths)
        tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
        weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
        for i, t in enumerate(per_tokens):
            tokens_padded[i, : t.shape[0]] = t
            weights_padded[i, : t.shape[0]] = per_weights[i]

        # Stack the expert weights into (E, out_features, in_features) tensors.
        l1, l2 = [], []
        for _, expert in experts_list:
            l1.append(expert.gate_and_up_proj)
            l2.append(expert.down_proj)

        W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
        W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)

        W1_T = W1.transpose(1, 2)
        W2_T = W2.transpose(1, 2)

        # Batched SwiGLU MLP: chunk the fused gate/up projection and gate it with
        # SiLU before the down projection, matching HunyuanMLP.forward.
        x = torch.bmm(tokens_padded, W1_T)
        x1, x2 = x.chunk(2, dim=2)
        x = x1 * F.silu(x2)

        out_padded = torch.bmm(x, W2_T)

        out_padded = out_padded * weights_padded.unsqueeze(-1)

        # Scatter the weighted expert outputs back to their original token positions.
        for i, token_positions in enumerate(per_pos):
            Ni = lengths[i]
            out_i = out_padded[i, :Ni]
            combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)

        combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size))
        #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
        #chunks = dispatched_input.chunk(self.num_experts, dim=0)
        #expert_outputs = []
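The hunk above is the core of the change, so here is a self-contained toy version of the padding-plus-torch.bmm pattern it introduces. Expert count, sizes, and the single-expert-per-token routing are simplified for illustration and are not the model's real values.

# Toy illustration: vectorize per-expert SwiGLU MLPs with padding + torch.bmm
import torch
import torch.nn.functional as F

num_experts, hidden, inter, num_tokens = 3, 8, 16, 10
tokens = torch.randn(num_tokens, hidden)
assignment = torch.randint(0, num_experts, (num_tokens,))   # one expert per token
W1 = torch.randn(num_experts, 2 * inter, hidden)            # fused gate+up weights
W2 = torch.randn(num_experts, hidden, inter)                # down-projection weights

# Group tokens by expert and pad each group to the longest one.
groups = [torch.nonzero(assignment == e, as_tuple=False).squeeze(1) for e in range(num_experts)]
L = max(g.numel() for g in groups)
padded = torch.zeros(num_experts, L, hidden)
for e, g in enumerate(groups):
    padded[e, : g.numel()] = tokens[g]

# One batched matmul per projection instead of a Python loop over experts.
h = torch.bmm(padded, W1.transpose(1, 2))
x1, x2 = h.chunk(2, dim=2)
out_padded = torch.bmm(x1 * F.silu(x2), W2.transpose(1, 2))

# Scatter the results back to the original token order.
combined = torch.zeros_like(tokens)
for e, g in enumerate(groups):
    combined.index_add_(0, g, out_padded[e, : g.numel()])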
@@ -1014,12 +1048,12 @@ class HunyuanImage3ForCausalMM(nn.Module):
            dtype=x.dtype,
        )

        image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool)
        image_mask = torch.ones(x.size(1))
        image_mask[:, :5] = torch.zeros(5); image_mask[:, -4:] = torch.zeros(4)
        gen_timestep_scatter_index = 4
        cond, uncond = condition[:4], condition[4:]
        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]

        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
        height, width = x.shape[2] * 16, x.shape[3] * 16
        token_height = height // (16 * 16)
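A minimal, hypothetical reading of the image_mask construction above: mark every position as an image token except the first 5 and last 4 (presumably special/text tokens). The sequence length below is made up for illustration.

# Hypothetical image-token mask sketch
import torch

seq_len = 32
image_mask = torch.ones(1, seq_len, dtype=torch.bool)
image_mask[:, :5] = False
image_mask[:, -4:] = False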
@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.hunyuan_image_3.model

import comfy.model_management
import comfy.patcher_extension
@@ -1196,6 +1197,10 @@ class Hunyuan3Dv2(BaseModel):
        if guidance is not None:
            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
        return out

class HunyuanImage3(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device, unet_model=comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM)

class HiDream(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):