mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 12:20:16 +08:00

updated the hunyuan moe forward

split the forward pass between ready and pending experts

This commit is contained in:
parent 76e14d69b2
commit a58133f188
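
The idea of the change, as a rough standalone sketch (not the actual ComfyUI code; forward_with_overlap, experts and token_groups are made-up names, and concurrent.futures stands in for whatever loader hands back the futures): run the experts that are already resident first so the GPU has work to do while the remaining experts finish loading, then block on the stragglers.

import concurrent.futures

def forward_with_overlap(experts, token_groups):
    # experts: dict {index: module, or Future resolving to a module}
    # token_groups: dict {index: tokens routed to that expert}
    ready, pending = [], []
    for i, e in experts.items():
        if isinstance(e, concurrent.futures.Future) and not e.done():
            pending.append(i)
        else:
            ready.append(i)

    outputs = {}
    # 1) run the experts that are already available; the pending ones keep loading meanwhile
    for i in ready:
        e = experts[i]
        experts[i] = e.result() if isinstance(e, concurrent.futures.Future) else e
        outputs[i] = experts[i](token_groups[i])

    # 2) only now wait for the experts that were still loading, then run them
    for i in pending:
        experts[i] = experts[i].result()
        outputs[i] = experts[i](token_groups[i])
    return outputs

In the commit the same split is done with asyncio tasks/futures inside HunyuanMoE.forward, and the per-expert work is the batched matmul in compute_expert_outputs in the second hunk below.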
@@ -669,6 +669,8 @@ class HunyuanMoE(nn.Module):
         self.moe_lru = moe_lru
 
     def forward(self, hidden_states):
+        # do the forward pass over the already loaded experts to give the pending experts time to finish loading
+        # keeps the gpu from sitting idle
         if not INIT_MOE:
             torch.cuda.set_device(0)
         else:
@@ -696,50 +698,68 @@ class HunyuanMoE(nn.Module):
         else:
             tokens_padded = dispatched_input[used_indices]
 
-        l1, l2 = [], []
+        def compute_expert_outputs(experts_list, tokens_padded, device):
+            l1, l2 = [], []
+            for m in experts_list:
+                l1.append(m.gate_and_up_proj)
+                l2.append(m.down_proj)
+
+            W1 = torch.stack([m.weight.to(device) for m in l1], dim=0)
+            W1_T = W1.transpose(1, 2)
+            x = torch.bmm(tokens_padded.to(device), W1_T)
+            x1, x2 = x.chunk(2, dim=2)
+            gated = x1 * F.silu(x2)
+
+            W2 = torch.stack([m.weight.to(device) for m in l2], dim=0)
+            W2_T = W2.transpose(1, 2)
+            out_padded = torch.bmm(gated, W2_T)
+            return out_padded
+
+        out_parts = {}
+        ready_indices, pending_indices, pending_futures = [], [], {}
+
         for i in used_indices:
             expert = self.experts[i]
-            if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                expert = expert.result()
-            expert = expert.to(device)
-            l1.append(expert.gate_and_up_proj)
-            l2.append(expert.down_proj)
-
-        compute_device = hidden_states.device
-        l1 = [m.to(compute_device) for m in l1]
-        W1 = torch.stack([m.weight for m in l1], dim=0)
-        del l1
-        W1_T = W1.transpose(1, 2)
-        del W1
-        x = torch.bmm(tokens_padded, W1_T)
-        del W1_T, tokens_padded
-        x1, x2 = x.chunk(2, dim=2)
-        gated = x1 * F.silu(x2)
-        l2 = [m.to(compute_device) for m in l2]
-        W2 = torch.stack([m.weight for m in l2], dim=0)
-        del l2
-        W2_T = W2.transpose(1, 2)
-        del W2
-        out_padded = torch.bmm(gated, W2_T)
-        del W2_T
-
-        while not enough_vram(3*(1024 ** 3)):
-            event = self.moe_lru.last_offload_event
-            if event is not None and not event.query():
-                time.sleep(0.001)
-
-        combine_weights_used = combine_weights[:, used_indices, :]
-        combined_output = torch.einsum("suc,ucm->sm",
-                                       combine_weights_used.type_as(out_padded),
-                                       out_padded
-                                       )
-        del x, x1, x2, gated, out_padded
+            if isinstance(expert, asyncio.Task) or isinstance(expert, asyncio.Future):
+                if expert.done():
+                    self.experts[i] = expert.result()
+                    ready_indices.append(i)
+                else:
+                    pending_indices.append(i)
+                    pending_futures[i] = expert
+            else:
+                ready_indices.append(i)
+
+        ready_pos = [used_indices.index(i) for i in ready_indices]
+        pending_pos = [used_indices.index(i) for i in pending_indices]
+
+        if ready_indices:
+            ready_experts = [self.experts[i] if not (isinstance(self.experts[i], asyncio.Task) or isinstance(self.experts[i], asyncio.Future))
+                             else self.experts[i].result()
+                             for i in ready_indices]
+            tokens_for_ready = tokens_padded[ready_pos]
+            out_ready = compute_expert_outputs(ready_experts, tokens_for_ready, device)
+            for idx_pos, expert_idx in enumerate(ready_indices):
+                out_parts[expert_idx] = out_ready[idx_pos:idx_pos+1]
+
+        for i in pending_indices:
+            expert = self.experts[i]
+            if isinstance(expert, asyncio.Future):
+                loaded_expert = expert.result()
+                self.experts[i] = loaded_expert
+
+        if pending_indices:
+            pending_experts = [self.experts[i] for i in pending_indices]
+            tokens_for_pending = tokens_padded[pending_pos]
+            out_pending = compute_expert_outputs(pending_experts, tokens_for_pending, device)
+            for idx_pos, expert_idx in enumerate(pending_indices):
+                out_parts[expert_idx] = out_pending[idx_pos:idx_pos+1]
+
+        out_list_ordered = [out_parts[i] for i in used_indices]
+        out_padded_all = torch.cat(out_list_ordered, dim=0)
+
+        combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)
+
+        del out_padded_all, out_list_ordered, out_parts
 
         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
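
For reference, a self-contained sketch of the batched expert MLP that compute_expert_outputs performs above, with made-up sizes (E experts, C tokens of capacity per expert); the real weights come from each expert's gate_and_up_proj and down_proj layers:

import torch
import torch.nn.functional as F

E, C, d_model, d_ff = 4, 8, 16, 32           # experts, capacity per expert, hidden size, ffn size
tokens = torch.randn(E, C, d_model)          # tokens already dispatched/padded per expert
W1 = torch.randn(E, 2 * d_ff, d_model)       # stacked gate_and_up_proj weights, (out, in) per expert
W2 = torch.randn(E, d_model, d_ff)           # stacked down_proj weights

x = torch.bmm(tokens, W1.transpose(1, 2))    # (E, C, 2*d_ff): one batched GEMM for every expert
x1, x2 = x.chunk(2, dim=2)                   # split the fused projection into its two halves
gated = x1 * F.silu(x2)                      # SwiGLU-style gating, as in the hunk above
out = torch.bmm(gated, W2.transpose(1, 2))   # (E, C, d_model): per-expert outputs, still padded

Stacking the weights and doing one bmm per projection avoids a Python-level loop over experts, which is why the helper takes a whole list of experts at once and is called separately for the ready and the pending groups.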
@@ -1339,7 +1339,7 @@ class HunyuanImage3(supported_models_base.BASE):
     latent_format = latent_formats.HunyuanImage3
 
     def get_model(self, state_dict, prefix="", device=None):
-        state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
+        self.wte_sd = state_dict["model.model.wte"]
         state_dict.pop("model.model.wte", None)
         model = model_base.HunyuanImage3(self, device = device)
 
@@ -1349,6 +1349,8 @@ class HunyuanImage3(supported_models_base.BASE):
         return model
 
     def clip_target(self, state_dict={}):
+        clip = comfy.text_encoders.hunyuan_image.HunyuanImage3
+        clip.embed_wte = self.wte_sd
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.HunyuanImage3)
 
 class HunyuanImage21(HunyuanVideo):
@@ -7,9 +7,11 @@ import os
 import re
 
 class HunyuanImage3TextEncoder(torch.nn.Module):
+    embed_wte = None
     def __init__(self):
         super().__init__()
-        self.wte = torch.nn.Embedding(133120, 4096, 128009)
+        self.wte = torch.nn.Embedding(133120, 4096, padding_idx = 128009)
+        self.wte.data = self.embed_wte
     def forward(self, x):
         out = self.wte(x)
         return out, torch.empty_like(out)
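
Taken together, the last three hunks hand the wte embedding table from the checkpoint to the text encoder through a class attribute instead of routing it through the state dict: roughly, get_model stashes it as self.wte_sd, clip_target attaches it to the text encoder class as embed_wte, and __init__ reads embed_wte when building its wte layer. A condensed sketch of that pattern with illustrative names (the sketch writes into wte.weight.data, the tensor where nn.Embedding keeps its table):

import torch

class SketchTextEncoder(torch.nn.Module):
    embed_wte = None                                # set on the class before instantiation

    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(133120, 4096, padding_idx=128009)
        if self.embed_wte is not None:
            self.wte.weight.data = self.embed_wte   # reuse the checkpoint's table instead of a fresh one

# loader side, roughly what get_model / clip_target do with "model.model.wte":
# SketchTextEncoder.embed_wte = state_dict.pop("model.model.wte")
# encoder = SketchTextEncoder()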