mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 08:40:50 +08:00
some fixes
This commit is contained in:
parent
a58133f188
commit
3dd39efa03
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import asyncio
|
||||
import threading
|
||||
@ -590,6 +589,8 @@ class LazyMoELoader(nn.Module):
|
||||
self.config,
|
||||
layer_idx=layer,
|
||||
device="meta",
|
||||
dtype = self.dtype,
|
||||
operations = self.operations
|
||||
)
|
||||
return pool
|
||||
|
||||
@ -720,7 +721,7 @@ class HunyuanMoE(nn.Module):
|
||||
|
||||
for i in used_indices:
|
||||
expert = self.experts[i]
|
||||
if isinstance(expert, asyncio.Task) or isinstance(expert, asyncio.Future):
|
||||
if isinstance(expert, concurrent.futures.Future) or isinstance(expert, asyncio.Future):
|
||||
if expert.done():
|
||||
self.experts[i] = expert.result()
|
||||
ready_indices.append(i)
|
||||
@ -733,7 +734,7 @@ class HunyuanMoE(nn.Module):
|
||||
pending_pos = [used_indices.index(i) for i in pending_indices]
|
||||
|
||||
if ready_indices:
|
||||
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], asyncio.Task) or isinstance(self.experts[i], asyncio.Future))
|
||||
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], concurrent.futures.Future) or isinstance(self.experts[i], asyncio.Future))
|
||||
else self.experts[i].result()
|
||||
for i in ready_indices]
|
||||
tokens_for_ready = tokens_padded[ready_pos]
|
||||
@ -757,6 +758,7 @@ class HunyuanMoE(nn.Module):
|
||||
out_list_ordered = [out_parts[i] for i in used_indices]
|
||||
out_padded_all = torch.cat(out_list_ordered, dim=0)
|
||||
|
||||
combine_weights = combine_weights.to(out_padded_all.dtype)
|
||||
combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)
|
||||
|
||||
del out_padded_all, out_list_ordered, out_parts
|
||||
@ -794,7 +796,7 @@ class HunyuanImage3Attention(nn.Module):
|
||||
self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]), device=device, dtype=dtype
|
||||
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
@ -905,7 +907,7 @@ class HunyuanImage3Model(nn.Module):
|
||||
self.config = config
|
||||
self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList(
|
||||
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, devuce=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
|
||||
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
|
||||
)
|
||||
|
||||
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
|
||||
@ -914,6 +916,8 @@ class HunyuanImage3Model(nn.Module):
|
||||
self.moe_lru = moe_lru
|
||||
self.additional_layers_set = False
|
||||
self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
|
||||
self.moe_loader.operations = operations
|
||||
self.moe_loader.dtype = dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -996,7 +1000,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
self.config = config
|
||||
factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
|
||||
|
||||
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
|
||||
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
|
||||
self.patch_embed = UNetDown(
|
||||
patch_size=1,
|
||||
emb_channels=config["hidden_size"],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user