mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 21:00:16 +08:00
vectorized implementation of moe/fixes for issues
This commit is contained in:
parent de43880bdb
commit a2fff60d4c
@@ -4,6 +4,7 @@ import math
 import torch
 import psutil
 import torch.nn as nn
+from pathlib import Path
 from einops import rearrange
 import torch.nn.functional as F
 from collections import OrderedDict
@@ -460,13 +461,13 @@ class HunyuanMLP(nn.Module):
         self.intermediate_size = 3072

         self.act_fn = torch.nn.functional.silu
-        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
-        if x.ndim == 2:
-            x = x.unsqueeze(0)
         self.intermediate_size *= 2 # SwiGLU
         self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=False)

     def forward(self, x):
+        self.gate_and_up_proj, self.down_proj = self.gate_and_up_proj.to(x.device), self.down_proj.to(x.device)
+        if x.ndim == 2:
+            x = x.unsqueeze(0)
         gate_and_up_proj = self.gate_and_up_proj(x)
         x1, x2 = gate_and_up_proj.chunk(2, dim=2)
         down_proj = self.down_proj(x1 * self.act_fn(x2))
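The device-transfer and unsqueeze lines move out of __init__ (where x does not exist) into forward. For reference, the chunk in HunyuanMLP.forward is the SwiGLU pattern: the fused gate_and_up_proj emits both halves in one matmul, and silu on one half gates the other. A minimal self-contained sketch with toy sizes (not the model's real 3072-wide layer):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, hidden_size=64, intermediate_size=128):
        super().__init__()
        # 2x width so a single matmul produces both the gate and up halves
        self.gate_and_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):  # x: (batch, seq, hidden)
        x1, x2 = self.gate_and_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(x1 * F.silu(x2))  # silu on one half gates the other

print(SwiGLU()(torch.randn(1, 4, 64)).shape)  # torch.Size([1, 4, 64])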
@@ -654,7 +655,9 @@ class LazyMoELoader(nn.Module):
         self.device = device

     def lazy_init(self, config, layer_idx, expert_idx):
-        checkpoint = "./models/checkpoint/hunyuan_image_3.safetensors"
+        comfyui_dir = Path.home() / "ComfyUI"
+        checkpoint = comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors"
+        checkpoint = checkpoint.resolve()
         if not os.path.exists(checkpoint):
             raise ValueError(f"Hunyuan Image 3 Checkpoint on one GPU should have the path: {checkpoint}")

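The checkpoint path is now anchored to the user's home directory instead of the process working directory, so lazy_init no longer depends on where ComfyUI was launched from. The same resolution in isolation, assuming the ~/ComfyUI install location the code expects:

import os
from pathlib import Path

comfyui_dir = Path.home() / "ComfyUI"
checkpoint = (comfyui_dir / "models" / "checkpoint" / "hunyuan_image_3.safetensors").resolve()
# os.path.exists accepts Path objects, so the existing check keeps working
print(checkpoint, os.path.exists(checkpoint))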
@@ -720,34 +723,65 @@ class HunyuanMoE(nn.Module):
         self.moe_lru.add_cpu(expert_cpu, expert_id + self.layer_idx)

         combined_output = torch.zeros_like(reshaped_input)
+        experts_list = []
         for e in range(self.num_experts):
             token_mask = (expert_index == e)
             if not token_mask.any():
                 continue
+            expert = self.moe_lru.get_from_device(e + self.layer_idx)
+            if expert is None:
+                expert = LazyMoELoader()
+                expert = expert.lazy_init(self.config, self.layer_idx, e)
+                self.moe_lru.add_gpu(expert, e + self.layer_idx)
+            experts_list.append((e, expert))
+
+        per_pos, per_tokens, per_weights = [], [], []
+        for e, _ in experts_list:
+            token_mask = (expert_index == e)
+
             token_ids = token_mask.nonzero(as_tuple=False)
             token_positions = token_ids[:, 0]

             topk_slot = token_ids[:, 1]

             tokens = reshaped_input[token_positions]
             weights = expert_weight[token_positions, topk_slot]

-            if self.experts is not None and INIT_MOE:
-                out = self.experts[e](tokens)
-            elif self.experts is None:
-                expert = self.moe_lru.get_from_device(e + self.layer_idx)
-                if expert is None:
-                    expert = LazyMoELoader()
-                    out = expert.lazy_init(self.config, self.layer_idx, e)(tokens)
-                    self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                else:
-                    tokens = tokens.to(next(expert.parameters()).device)
-                    out = expert(tokens.view(bsz, -1, hidden_size))
-
-            out = out * weights.to(out.device).unsqueeze(-1)
-            combined_output.to(out.device).index_add_(0, token_positions.to(out.device), out.reshape(-1, hidden_size))
+            per_pos.append(token_positions)
+            per_tokens.append(tokens)
+            per_weights.append(weights)
+
+        lengths = [t.shape[0] for t in per_tokens]
+        E = len(per_tokens)
+        L = max(lengths)
+        tokens_padded = torch.zeros((E, L, hidden_size), device=hidden_states.device, dtype=reshaped_input.dtype)
+        weights_padded = torch.zeros((E, L), device=hidden_states.device, dtype=per_weights[0].dtype)
+        for i, t in enumerate(per_tokens):
+            tokens_padded[i, : t.shape[0]] = t
+            weights_padded[i, : t.shape[0]] = per_weights[i]
+
+        l1, l2 = [], []
+        for _, expert in experts_list:
+            l1.append(expert.gate_and_up_proj)
+            l2.append(expert.down_proj)
+
+        W1 = torch.stack([l.weight for l in l1]).to(hidden_states.device)
+        W2 = torch.stack([l.weight for l in l2]).to(hidden_states.device)
+
+        W1_T = W1.transpose(1, 2)
+        W2_T = W2.transpose(1, 2)
+
+        x = torch.bmm(tokens_padded, W1_T)
+        x1, x2 = x.chunk(2, dim=-1)  # SwiGLU: split the fused gate/up projection before activating
+        x = x1 * F.silu(x2)
+
+        out_padded = torch.bmm(x, W2_T)

+        out_padded = out_padded * weights_padded.unsqueeze(-1)
+
+        for i, token_positions in enumerate(per_pos):
+            Ni = lengths[i]
+            out_i = out_padded[i, :Ni]
+            combined_output.index_add_(0, token_positions.to(hidden_states.device), out_i)

         #dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
         #chunks = dispatched_input.chunk(self.num_experts, dim=0)
         #expert_outputs = []
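The rewrite splits the old per-expert loop into three phases: collect the active experts through the LRU cache, gather each expert's tokens and routing weights, then run every expert at once by padding the ragged token batches to a common length and stacking the expert weights so the whole MLP reduces to two torch.bmm calls. Note the gate/up chunk before the activation: silu over the full fused projection would leave a 2x-intermediate tensor that the stacked down-projection weights cannot multiply. A self-contained sketch of the padded-bmm technique with toy shapes (names are illustrative, not the module's real attributes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
E, hidden, inter = 3, 8, 16              # experts, hidden size, per-expert intermediate
counts = [5, 2, 7]                       # ragged token counts routed to each expert
tokens = [torch.randn(n, hidden) for n in counts]
weights = [torch.rand(n) for n in counts]

W1 = torch.randn(E, 2 * inter, hidden)   # stacked fused gate+up projections
W2 = torch.randn(E, hidden, inter)       # stacked down projections

# Pad every expert's token batch to the longest count
L = max(counts)
tokens_padded = torch.zeros(E, L, hidden)
weights_padded = torch.zeros(E, L)
for i, t in enumerate(tokens):
    tokens_padded[i, : t.shape[0]] = t
    weights_padded[i, : t.shape[0]] = weights[i]

# One batched matmul per projection instead of a Python loop over experts
x = torch.bmm(tokens_padded, W1.transpose(1, 2))             # (E, L, 2*inter)
x1, x2 = x.chunk(2, dim=-1)                                  # SwiGLU halves
out_padded = torch.bmm(x1 * F.silu(x2), W2.transpose(1, 2))  # (E, L, hidden)
out_padded = out_padded * weights_padded.unsqueeze(-1)       # padded rows weigh 0

# Reference: the straightforward per-expert loop
for i, t in enumerate(tokens):
    h1, h2 = (t @ W1[i].T).chunk(2, dim=-1)
    ref = ((h1 * F.silu(h2)) @ W2[i].T) * weights[i][:, None]
    assert torch.allclose(out_padded[i, : counts[i]], ref, atol=1e-5)
print("padded bmm matches the per-expert loop")

The trade-off is wasted FLOPs on zero padding when token counts are highly skewed, in exchange for eliminating per-expert kernel launches; rows padded with zeros carry zero routing weight, so they never leak into combined_output.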
@@ -1014,12 +1048,12 @@ class HunyuanImage3ForCausalMM(nn.Module):
             dtype=x.dtype,
         )

-        image_mask = torch.arange(1, x.size(1) - 1).to(torch.bool)
+        image_mask = torch.ones(x.size(1), dtype=torch.bool)
+        image_mask[:5] = False; image_mask[-4:] = False
         gen_timestep_scatter_index = 4
         cond, uncond = condition[:4], condition[4:]
         joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
-

         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
         token_height = height // (16 * 16)
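image_mask now starts from an all-True vector over the sequence and switches off the 5 leading and 4 trailing positions (presumably special tokens, as the indices suggest); the previous torch.arange(1, x.size(1) - 1).to(torch.bool) produced a tensor of the wrong length that was True everywhere except index 0. The shape of the result on a toy 12-token sequence:

import torch

seq_len = 12
image_mask = torch.ones(seq_len, dtype=torch.bool)
image_mask[:5] = False   # leading positions
image_mask[-4:] = False  # trailing positions
print(image_mask.int().tolist())  # [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]

The remaining hunks register the new model class; the surrounding imports (comfy.ldm.hidream.model and friends) suggest they land in ComfyUI's model registry module.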
@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
+import comfy.ldm.hunyuan_image_3.model

 import comfy.model_management
 import comfy.patcher_extension
@@ -1197,6 +1198,10 @@ class Hunyuan3Dv2(BaseModel):
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
         return out

+class HunyuanImage3(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM)
+
 class HiDream(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
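HunyuanImage3 follows the same registration pattern as its neighbour HiDream: subclass BaseModel, pin the default ModelType, and hand the transformer class to unet_model. A schematic of that pattern with simplified stand-ins (the real comfy.model_base.BaseModel also wires up the diffusion model, conds, and dtype handling):

from enum import Enum

class ModelType(Enum):
    EPS = 1
    FLOW = 2

class BaseModel:
    # Stand-in for comfy.model_base.BaseModel, reduced to the wiring shown above
    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=None):
        self.model_config = model_config
        self.model_type = model_type
        self.diffusion_model = unet_model  # the real class instantiates it from model_config

class HunyuanImage3(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # unet_model stands in for comfy.ldm.hunyuan_image_3.model.HunyuanImage3ForCausalMM
        super().__init__(model_config, model_type, device=device, unet_model=object)

m = HunyuanImage3(model_config={})
print(m.model_type)  # ModelType.FLOW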