diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 5a717bf38..3518d1a0d 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -669,6 +669,8 @@ class HunyuanMoE(nn.Module):
         self.moe_lru = moe_lru
 
     def forward(self, hidden_states):
+        # Run the already loaded experts first so pending expert loads can
+        # finish in the background and the GPU does not sit idle.
         if not INIT_MOE:
             torch.cuda.set_device(0)
         else:
@@ -696,44 +698,77 @@ class HunyuanMoE(nn.Module):
         else:
             tokens_padded = dispatched_input[used_indices]
 
-        l1, l2 = [], []
+        def compute_expert_outputs(experts_list, tokens_padded, device):
+            # Batch every selected expert into two bmm calls: one for the
+            # fused gate/up projection and one for the down projection.
+            l1, l2 = [], []
+            for m in experts_list:
+                l1.append(m.gate_and_up_proj)
+                l2.append(m.down_proj)
+
+            W1 = torch.stack([m.weight.to(device) for m in l1], dim=0)
+            W1_T = W1.transpose(1, 2)
+            x = torch.bmm(tokens_padded.to(device), W1_T)
+            x1, x2 = x.chunk(2, dim=2)
+            gated = x1 * F.silu(x2)
+
+            W2 = torch.stack([m.weight.to(device) for m in l2], dim=0)
+            W2_T = W2.transpose(1, 2)
+            out_padded = torch.bmm(gated, W2_T)
+            return out_padded
+
+        out_parts = {}
+        ready_indices, pending_indices, pending_futures = [], [], {}
+
         for i in used_indices:
             expert = self.experts[i]
-            if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
-                expert = expert.result()
-            expert = expert.to(device)
-            l1.append(expert.gate_and_up_proj)
-            l2.append(expert.down_proj)
+            if isinstance(expert, (asyncio.Future, concurrent.futures.Future)):
+                if expert.done():
+                    self.experts[i] = expert.result()
+                    ready_indices.append(i)
+                else:
+                    pending_indices.append(i)
+                    pending_futures[i] = expert
+            else:
+                ready_indices.append(i)
+
+        ready_pos = [used_indices.index(i) for i in ready_indices]
+        pending_pos = [used_indices.index(i) for i in pending_indices]
-        compute_device = hidden_states.device
+        if ready_indices:
+            # Done futures were already resolved into self.experts in the
+            # partitioning loop above.
+            ready_experts = [self.experts[i] for i in ready_indices]
+            tokens_for_ready = tokens_padded[ready_pos]
+            out_ready = compute_expert_outputs(ready_experts, tokens_for_ready, device)
+            for idx_pos, expert_idx in enumerate(ready_indices):
+                out_parts[expert_idx] = out_ready[idx_pos:idx_pos+1]
-        l1 = [m.to(compute_device) for m in l1]
-        W1 = torch.stack([m.weight for m in l1], dim=0)
-        del l1
-        W1_T = W1.transpose(1, 2)
+        # Block until the remaining loads finish. concurrent.futures futures
+        # block inside result(); asyncio futures are polled instead, because
+        # their result() raises if called before they are done.
+        for i in pending_indices:
+            expert = pending_futures[i]
+            if isinstance(expert, concurrent.futures.Future):
+                self.experts[i] = expert.result()
+            else:
+                while not expert.done():
+                    time.sleep(0.001)
+                self.experts[i] = expert.result()
-        del W1
-        x = torch.bmm(tokens_padded, W1_T)
-        del W1_T, tokens_padded
+        if pending_indices:
+            pending_experts = [self.experts[i] for i in pending_indices]
+            tokens_for_pending = tokens_padded[pending_pos]
+            out_pending = compute_expert_outputs(pending_experts, tokens_for_pending, device)
+            for idx_pos, expert_idx in enumerate(pending_indices):
+                out_parts[expert_idx] = out_pending[idx_pos:idx_pos+1]
-        x1, x2 = x.chunk(2, dim=2)
-        gated = x1 * F.silu(x2)
+        # Reassemble the per-expert outputs in used_indices order.
+        out_list_ordered = [out_parts[i] for i in used_indices]
+        out_padded_all = torch.cat(out_list_ordered, dim=0)
-        l2 = [m.to(compute_device) for m in l2]
-        W2 = torch.stack([m.weight for m in l2], dim=0)
-        del l2
-        W2_T = W2.transpose(1, 2)
-        del W2
-        out_padded = torch.bmm(gated, W2_T)
-        del W2_T
-
-        while not enough_vram(3*(1024 ** 3)):
-            event = self.moe_lru.last_offload_event
-            if event is not None and not event.query():
-                time.sleep(0.001)
-
-        combine_weights_used = combine_weights[:, used_indices, :]
-
-        combined_output = torch.einsum("suc,ucm->sm",
-                                       combine_weights_used.type_as(out_padded),
-                                       out_padded
-                                       )
-
-        del x, x1, x2, gated, out_padded
+        # out_padded_all only holds the used experts, so slice combine_weights
+        # to match before combining.
+        combine_weights_used = combine_weights[:, used_indices, :]
+        combined_output = torch.einsum("suc,uco->so",
+                                       combine_weights_used.type_as(out_padded_all),
+                                       out_padded_all)
+
+        del out_padded_all, out_list_ordered, out_parts
 
         combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 43bc4b2e4..c9bd19363 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1339,7 +1339,7 @@ class HunyuanImage3(supported_models_base.BASE):
     latent_format = latent_formats.HunyuanImage3
 
     def get_model(self, state_dict, prefix="", device=None):
-        state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
+        self.wte_sd = state_dict["model.model.wte"]
        state_dict.pop("model.model.wte", None)

        model = model_base.HunyuanImage3(self, device = device)
@@ -1349,6 +1349,9 @@ class HunyuanImage3(supported_models_base.BASE):
         return model
 
     def clip_target(self, state_dict={}):
+        # Hand the embedding weights captured in get_model to the text encoder
+        # class, which reads embed_wte when it is constructed.
+        comfy.text_encoders.hunyuan_image.HunyuanImage3TextEncoder.embed_wte = self.wte_sd
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.HunyuanImage3)
 
 class HunyuanImage21(HunyuanVideo):
diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py
index 732b3d80a..c67e08096 100644
--- a/comfy/text_encoders/hunyuan_image.py
+++ b/comfy/text_encoders/hunyuan_image.py
@@ -7,9 +7,13 @@ import os
 import re
 class HunyuanImage3TextEncoder(torch.nn.Module):
+    embed_wte = None
     def __init__(self):
         super().__init__()
-        self.wte = torch.nn.Embedding(133120, 4096, 128009)
+        self.wte = torch.nn.Embedding(133120, 4096, padding_idx=128009)
+        # Assign to weight.data so the pretrained table replaces the random init.
+        if self.embed_wte is not None:
+            self.wte.weight.data = self.embed_wte
     def forward(self, x):
         out = self.wte(x)
         return out, torch.empty_like(out)
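The scheduling idea in the MoE hunk reduces to the pattern below: compute on experts whose loads already finished, and only then block on the rest. This is a minimal, self-contained sketch assuming a thread-pool loader; `load_expert`, `run_expert`, and the sizes are illustrative, not ComfyUI API.

```python
# Sketch of the ready/pending overlap in HunyuanMoE.forward, assuming experts
# arrive as concurrent.futures futures from a thread-pool loader.
import concurrent.futures
import time

def load_expert(i):
    time.sleep(0.05)              # stand-in for a host-to-GPU weight upload
    return f"expert-{i}"

def run_expert(expert):
    time.sleep(0.01)              # stand-in for the batched bmm compute
    return f"out({expert})"

pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
experts = {i: pool.submit(load_expert, i) for i in range(8)}
used_indices = [0, 2, 5]

# Partition into experts whose load already finished and ones still in flight.
ready = [i for i in used_indices if experts[i].done()]
pending = [i for i in used_indices if i not in ready]

outputs = {}
for i in ready:                   # compute on what is available now, while
    outputs[i] = run_expert(experts[i].result())      # pending loads continue
for i in pending:
    outputs[i] = run_expert(experts[i].result())      # result() blocks here

print([outputs[i] for i in used_indices])             # back in used-index order
pool.shutdown()
```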
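The combine step contracts the sliced routing weights against the stacked per-expert outputs, which is why `combine_weights` has to be indexed by `used_indices` first. A quick shape check with toy sizes (all names and sizes here are arbitrary):

```python
# Toy-size shape check for the combine einsum: s = flattened tokens,
# u = used experts, c = capacity per expert, o = model hidden size.
import torch

s, u, c, o = 16, 3, 4, 8
combine_weights_used = torch.rand(s, u, c)   # routing weights, used experts only
out_padded_all = torch.rand(u, c, o)         # stacked per-expert outputs
combined = torch.einsum("suc,uco->so", combine_weights_used, out_padded_all)
assert combined.shape == (s, o)
```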
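The wte handoff works by setting a class attribute before the encoder is instantiated and then copying the tensor into `weight.data` inside `__init__`. A minimal sketch of that pattern with toy sizes; `Enc` is illustrative, not the ComfyUI class:

```python
# Sketch of the class-attribute embedding handoff.
import torch

class Enc(torch.nn.Module):
    embed_wte = None                       # populated before instantiation

    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(10, 4, padding_idx=0)
        if self.embed_wte is not None:
            # weight.data assignment swaps in the pretrained table while the
            # Parameter object itself stays registered with the module.
            self.wte.weight.data = self.embed_wte

Enc.embed_wte = torch.randn(10, 4)         # stands in for the checkpoint tensor
enc = Enc()
assert torch.equal(enc.wte.weight, Enc.embed_wte)
```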