some fixes

This commit is contained in:
Yousef Rafat 2025-12-02 23:40:31 +02:00
parent a58133f188
commit 3dd39efa03

View File

@ -1,6 +1,5 @@
import os
import math
import time
import torch
import asyncio
import threading
@ -590,6 +589,8 @@ class LazyMoELoader(nn.Module):
self.config,
layer_idx=layer,
device="meta",
dtype = self.dtype,
operations = self.operations
)
return pool
@ -720,7 +721,7 @@ class HunyuanMoE(nn.Module):
for i in used_indices:
expert = self.experts[i]
if isinstance(expert, asyncio.Task) or isinstance(expert, asyncio.Future):
if isinstance(expert, concurrent.futures.Future) or isinstance(expert, asyncio.Future):
if expert.done():
self.experts[i] = expert.result()
ready_indices.append(i)
@ -733,7 +734,7 @@ class HunyuanMoE(nn.Module):
pending_pos = [used_indices.index(i) for i in pending_indices]
if ready_indices:
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], asyncio.Task) or isinstance(self.experts[i], asyncio.Future))
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], concurrent.futures.Future) or isinstance(self.experts[i], asyncio.Future))
else self.experts[i].result()
for i in ready_indices]
tokens_for_ready = tokens_padded[ready_pos]
@ -757,6 +758,7 @@ class HunyuanMoE(nn.Module):
out_list_ordered = [out_parts[i] for i in used_indices]
out_padded_all = torch.cat(out_list_ordered, dim=0)
combine_weights = combine_weights.to(out_padded_all.dtype)
combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)
del out_padded_all, out_list_ordered, out_parts
@ -794,7 +796,7 @@ class HunyuanImage3Attention(nn.Module):
self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]), device=device, dtype=dtype
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -905,7 +907,7 @@ class HunyuanImage3Model(nn.Module):
self.config = config
self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
self.layers = nn.ModuleList(
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, devuce=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
)
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
@ -914,6 +916,8 @@ class HunyuanImage3Model(nn.Module):
self.moe_lru = moe_lru
self.additional_layers_set = False
self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
self.moe_loader.operations = operations
self.moe_loader.dtype = dtype
def forward(
self,
@ -996,7 +1000,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
self.config = config
factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
self.patch_embed = UNetDown(
patch_size=1,
emb_channels=config["hidden_size"],