From 3dd39efa03c4e5e7e883c9298e2eb2ecf47bdbe3 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Tue, 2 Dec 2025 23:40:31 +0200
Subject: [PATCH] some fixes

---
 comfy/ldm/hunyuan_image_3/model.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 3518d1a0d..aa3dce524 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -1,6 +1,5 @@
 import os
 import math
-import time
 import torch
 import asyncio
 import threading
@@ -590,6 +589,8 @@ class LazyMoELoader(nn.Module):
                 self.config,
                 layer_idx=layer,
                 device="meta",
+                dtype=self.dtype,
+                operations=self.operations
             )

         return pool
@@ -720,7 +721,7 @@ class HunyuanMoE(nn.Module):

         for i in used_indices:
             expert = self.experts[i]
-            if isinstance(expert, asyncio.Task) or isinstance(expert, asyncio.Future):
+            if isinstance(expert, concurrent.futures.Future) or isinstance(expert, asyncio.Future):
                 if expert.done():
                     self.experts[i] = expert.result()
                     ready_indices.append(i)
@@ -733,7 +734,7 @@ class HunyuanMoE(nn.Module):
         pending_pos = [used_indices.index(i) for i in pending_indices]

         if ready_indices:
-            ready_experts = [self.experts[i] if not (isinstance(self.experts[i], asyncio.Task) or isinstance(self.experts[i], asyncio.Future))
+            ready_experts = [self.experts[i] if not (isinstance(self.experts[i], concurrent.futures.Future) or isinstance(self.experts[i], asyncio.Future))
                              else self.experts[i].result() for i in ready_indices]

             tokens_for_ready = tokens_padded[ready_pos]
@@ -757,6 +758,7 @@ class HunyuanMoE(nn.Module):

         out_list_ordered = [out_parts[i] for i in used_indices]
         out_padded_all = torch.cat(out_list_ordered, dim=0)
+        combine_weights = combine_weights.to(out_padded_all.dtype)
         combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)

         del out_padded_all, out_list_ordered, out_parts
@@ -794,7 +796,7 @@ class HunyuanImage3Attention(nn.Module):

         self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
         self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
-        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]), device=device, dtype=dtype
+        self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -905,7 +907,7 @@ class HunyuanImage3Model(nn.Module):
         self.config = config
         self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
         self.layers = nn.ModuleList(
-            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, devuce=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
+            [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru=moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
         )
         self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])

@@ -914,6 +916,8 @@
         self.moe_lru = moe_lru
         self.additional_layers_set = False
         self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
+        self.moe_loader.operations = operations
+        self.moe_loader.dtype = dtype

     def forward(
         self,
@@ -996,7 +1000,7 @@ class HunyuanImage3ForCausalMM(nn.Module):

         self.config = config
         factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
-        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
+        self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
         self.patch_embed = UNetDown(
             patch_size=1,
             emb_channels=config["hidden_size"],
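
Note for reviewers: the hunks in HunyuanMoE.forward standardize on polling concurrent.futures.Future objects for lazily loaded experts. Below is a minimal, self-contained sketch of that pattern under assumed names; DummyExpert, load_expert, and the pool sizes are illustrative stand-ins, not identifiers from the ComfyUI file.

import concurrent.futures

import torch
import torch.nn as nn


class DummyExpert(nn.Module):
    # Illustrative stand-in for a MoE expert.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


def load_expert(dim):
    # A real loader would materialize expert weights on a worker thread.
    return DummyExpert(dim)


executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
dim = 8
# Pending experts are stored as Futures; resolved ones become nn.Modules.
experts = [executor.submit(load_expert, dim) for _ in range(4)]

ready, pending = [], []
for i, expert in enumerate(experts):
    if isinstance(expert, concurrent.futures.Future):
        if expert.done():
            experts[i] = expert.result()  # swap the Future for the loaded module
            ready.append(i)
        else:
            pending.append(i)
    else:
        ready.append(i)

x = torch.randn(2, dim)
outputs = [experts[i](x) for i in ready]  # only run experts that finished loading
print(f"ready={ready} pending={pending}")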