mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-18 10:20:16 +08:00
some fixes
This commit is contained in:
parent
a58133f188
commit
3dd39efa03
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
import torch
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
@ -590,6 +589,8 @@ class LazyMoELoader(nn.Module):
|
|||||||
self.config,
|
self.config,
|
||||||
layer_idx=layer,
|
layer_idx=layer,
|
||||||
device="meta",
|
device="meta",
|
||||||
|
dtype = self.dtype,
|
||||||
|
operations = self.operations
|
||||||
)
|
)
|
||||||
return pool
|
return pool
|
||||||
|
|
||||||
@ -720,7 +721,7 @@ class HunyuanMoE(nn.Module):
|
|||||||
|
|
||||||
for i in used_indices:
|
for i in used_indices:
|
||||||
expert = self.experts[i]
|
expert = self.experts[i]
|
||||||
if isinstance(expert, asyncio.Task) or isinstance(expert, asyncio.Future):
|
if isinstance(expert, concurrent.futures.Future) or isinstance(expert, asyncio.Future):
|
||||||
if expert.done():
|
if expert.done():
|
||||||
self.experts[i] = expert.result()
|
self.experts[i] = expert.result()
|
||||||
ready_indices.append(i)
|
ready_indices.append(i)
|
||||||
@ -733,7 +734,7 @@ class HunyuanMoE(nn.Module):
|
|||||||
pending_pos = [used_indices.index(i) for i in pending_indices]
|
pending_pos = [used_indices.index(i) for i in pending_indices]
|
||||||
|
|
||||||
if ready_indices:
|
if ready_indices:
|
||||||
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], asyncio.Task) or isinstance(self.experts[i], asyncio.Future))
|
ready_experts = [self.experts[i] if not (isinstance(self.experts[i], concurrent.futures.Future) or isinstance(self.experts[i], asyncio.Future))
|
||||||
else self.experts[i].result()
|
else self.experts[i].result()
|
||||||
for i in ready_indices]
|
for i in ready_indices]
|
||||||
tokens_for_ready = tokens_padded[ready_pos]
|
tokens_for_ready = tokens_padded[ready_pos]
|
||||||
@ -757,6 +758,7 @@ class HunyuanMoE(nn.Module):
|
|||||||
out_list_ordered = [out_parts[i] for i in used_indices]
|
out_list_ordered = [out_parts[i] for i in used_indices]
|
||||||
out_padded_all = torch.cat(out_list_ordered, dim=0)
|
out_padded_all = torch.cat(out_list_ordered, dim=0)
|
||||||
|
|
||||||
|
combine_weights = combine_weights.to(out_padded_all.dtype)
|
||||||
combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)
|
combined_output = torch.einsum("suc,uco->so", combine_weights, out_padded_all)
|
||||||
|
|
||||||
del out_padded_all, out_list_ordered, out_parts
|
del out_padded_all, out_list_ordered, out_parts
|
||||||
@ -794,7 +796,7 @@ class HunyuanImage3Attention(nn.Module):
|
|||||||
self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
|
self.o_proj = operations.Linear(self.hidden_size_q, self.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||||
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"]), device=device, dtype=dtype
|
self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
@ -905,7 +907,7 @@ class HunyuanImage3Model(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
|
self.wte = operations.Embedding(133120, config["hidden_size"], self.padding_idx, device=device, dtype=dtype)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, devuce=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
|
[HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru, device=device, dtype=dtype, operations=operations) for layer_idx in range(config["num_hidden_layers"])]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
|
self.ln_f = HunyuanRMSNorm(config["hidden_size"], eps=config["rms_norm_eps"])
|
||||||
@ -914,6 +916,8 @@ class HunyuanImage3Model(nn.Module):
|
|||||||
self.moe_lru = moe_lru
|
self.moe_lru = moe_lru
|
||||||
self.additional_layers_set = False
|
self.additional_layers_set = False
|
||||||
self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
|
self.moe_loader = LazyMoELoader(self.moe_lru, self.config)
|
||||||
|
self.moe_loader.operations = operations
|
||||||
|
self.moe_loader.dtype = dtype
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -996,7 +1000,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
|
factory_kwargs = {"device": device, "dtype": dtype, "operations": operations}
|
||||||
|
|
||||||
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
|
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"], **factory_kwargs)
|
||||||
self.patch_embed = UNetDown(
|
self.patch_embed = UNetDown(
|
||||||
patch_size=1,
|
patch_size=1,
|
||||||
emb_channels=config["hidden_size"],
|
emb_channels=config["hidden_size"],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user