fixed some mistakes/errors

Yousef Rafat 2025-11-25 21:34:03 +02:00
parent b4001bbd27
commit 823870db53
3 changed files with 96 additions and 112 deletions

@@ -493,42 +493,44 @@ class MoELRUCache(nn.Module):
         self.last_offload_event = None
         self._loop = asyncio.new_event_loop()
+        self._gpu_sem = asyncio.Semaphore(1) # maybe 2
         threading.Thread(target=self._loop.run_forever, daemon=True).start()

     async def _async_offload_to_cpu(self, layer_idx):
-        # async offload from gpu (removed)
-        num_experts = 64
-        moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
-                     for i in range(num_experts)
-                     if (layer_idx * num_experts + i) in self.gpu_cache]
-        event = torch.cuda.Event()
-        with torch.cuda.stream(self.offload_stream):
-            for index, moe in moe_group:
-                moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
-                for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
-                    if p_gpu.device.type == "meta":
-                        continue
-                    with torch.no_grad():
-                        p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
-                        p_cpu.copy_(p_gpu, non_blocking=True)
+        async with self._gpu_sem:
+            num_experts = 64
+            moe_group = [(layer_idx * num_experts + i, self.gpu_cache[layer_idx * num_experts + i])
+                         for i in range(num_experts)
+                         if (layer_idx * num_experts + i) in self.gpu_cache]
+            event = torch.cuda.Event()
+            with torch.cuda.stream(self.offload_stream):
+                for index, moe in moe_group:
+                    moe_cpu = HunyuanMLP(moe.config).to("cpu", non_blocking=True)
+                    for (name, p_gpu), p_cpu in zip(moe.named_parameters(), moe_cpu.parameters()):
+                        if p_gpu.device.type == "meta":
+                            continue
+                        with torch.no_grad():
+                            p_cpu.data = torch.empty_like(p_gpu, device="cpu", pin_memory=True)
+                            p_cpu.copy_(p_gpu, non_blocking=True)
-                self.cpu_cache[index] = moe_cpu
+                    self.cpu_cache[index] = moe_cpu
-        self.offload_stream.record_event(event)
+            self.offload_stream.record_event(event)
-        self.last_offload_event = event
+            self.last_offload_event = event
-        def finalize_offload_layer():
-            event.synchronize()
-            for index, moe in moe_group:
-                moe.to("meta")
-                self.gpu_cache.pop(index, None)
-                del moe
-            torch.cuda.empty_cache()
+            def finalize_offload_layer():
+                event.synchronize()
+                for index, moe in moe_group:
+                    moe.to("meta")
+                    self.gpu_cache.pop(index, None)
+                    del moe
+                torch.cuda.empty_cache()
-        threading.Thread(target=finalize_offload_layer, daemon=True).start()
+            threading.Thread(target=finalize_offload_layer, daemon=True).start()

     async def _async_load_to_gpu(self, index, moe):
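The offload path above follows the standard recipe for overlapping GPU-to-CPU eviction with compute: copy on a dedicated stream into pinned host buffers, record a CUDA event, and free the GPU copies only after that event fires. A minimal self-contained sketch of the same technique (the names here are illustrative, not this repo's API):

    import torch

    def offload_params(params, offload_stream):
        # Copies run on a side stream so they overlap with compute
        # on the default stream instead of serializing behind it.
        event = torch.cuda.Event()
        host_copies = []
        with torch.cuda.stream(offload_stream):
            for p in params:
                # Pinned (page-locked) host memory is what makes
                # non_blocking=True an actually asynchronous DMA copy.
                buf = torch.empty_like(p, device="cpu", pin_memory=True)
                buf.copy_(p, non_blocking=True)
                host_copies.append(buf)
            offload_stream.record_event(event)
        # The caller must keep the GPU tensors alive until the event
        # completes, then it is safe to free them (as finalize_offload_layer does).
        return host_copies, event

The `asyncio.Semaphore(1)` wrapper added in this commit serializes whole-layer offloads, so at most one layer's worth of transfers competes for PCIe bandwidth at a time.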
@@ -632,37 +634,29 @@ class LazyMoELoader(nn.Module):
             getattr(model, name).data = tensor
         return model

-    def _register_expert_sync(self, layer_idx, expert_idx, moe_cpu):
-        self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
-        asyncio.run_coroutine_threadsafe(
-            self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
-            self.cache._loop
-        )

     async def lazy_load_from_disk(self, layer_idx, expert_idx):
         loop = asyncio.get_event_loop()
         async with self._semaphore:
-            moe_cpu = await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)
-            self._loop.call_soon_threadsafe(self._register_expert_sync, layer_idx, expert_idx, moe_cpu)
-            return moe_cpu
-    async def schedule_layer_load_progressive(self, layer_idx, num_experts = 64):
-        tasks = [asyncio.create_task(self.lazy_load_from_disk(layer_idx, i)) for i in range(num_experts)]
-        results = await asyncio.gather(*tasks, return_exceptions=False)
-        return results
-    def schedule_layer_load(self, layer_idx, num_experts = 64):
-        fut = asyncio.run_coroutine_threadsafe(
-            self.schedule_layer_load_progressive(layer_idx, num_experts),
-            self._loop
-        )
-        return fut
-    async def schedule_layer_load_(self, layer_idx):
-        tasks = [self.lazy_load_from_disk(layer_idx, i) for i in range(64)]
-        experts = await asyncio.gather(*tasks)
-        return experts
+            return await loop.run_in_executor(None, self.lazy_init, layer_idx, expert_idx)

+    def _schedule_disk_load(self, layer_idx, expert_idx):
+        coro = self.lazy_load_from_disk(layer_idx, expert_idx)
+        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+        def _on_disk_loaded(fut):
+            moe_cpu = fut.result()
+            def _add_cpu_in_main_thread():
+                self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+                asyncio.run_coroutine_threadsafe(
+                    self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+                    self.cache._loop
+                )
+            threading.Thread(target=_add_cpu_in_main_thread, daemon=True).start()
+        future.add_done_callback(_on_disk_loaded)
+        return future

 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
     return free > required_bytes
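`_schedule_disk_load` is built on a private asyncio loop that `MoELRUCache.__init__` runs in a daemon thread: any thread can hand a coroutine to that loop via `asyncio.run_coroutine_threadsafe`, which returns a `concurrent.futures.Future` whose `add_done_callback` fires when the coroutine finishes. A minimal sketch of the pattern, with a stand-in `load` coroutine:

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def load(idx):
        await asyncio.sleep(0.01)  # stand-in for the real disk read
        return f"expert-{idx}"

    # Safe to call from any thread; returns a concurrent.futures.Future.
    future = asyncio.run_coroutine_threadsafe(load(3), loop)
    future.add_done_callback(lambda fut: print("loaded:", fut.result()))
    print(future.result(timeout=5))  # or block until the value is ready

Note that the done-callback runs on the loop's own thread, which is why `_on_disk_loaded` hops to a fresh thread (`_add_cpu_in_main_thread`) before scheduling the follow-up GPU upload: blocking inside the callback would stall the event loop.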
@@ -748,7 +742,6 @@ class HunyuanMoE(nn.Module):
         if event is not None and not event.query():
             time.sleep(0.001)
-        combine_weights_used = combine_weights[:, used_indices, :]
         combined_output = torch.einsum("suc,ucm->sm",
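Instead of `event.synchronize()`, which blocks until the GPU work completes, this hot path polls `event.query()` and sleeps a millisecond between checks, keeping the Python thread responsive. As a standalone helper (a sketch, not code from this repo):

    import time
    import torch

    def wait_on_event(event, poll_interval=0.001):
        # event.query() returns True once everything recorded before
        # the event has finished executing on the GPU.
        while event is not None and not event.query():
            time.sleep(poll_interval)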
@@ -900,6 +893,7 @@ class HunyuanImage3Model(nn.Module):
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
+        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
         )
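The third positional argument of `nn.Embedding` is `padding_idx`: that row is pinned to zeros and receives no gradient, which is the desired behavior for the pad token (id 128009 here). A quick check, with the hidden size assumed to be 4096 purely for illustration:

    import torch
    import torch.nn as nn

    wte = nn.Embedding(133120, 4096, 128009)  # num_embeddings, embedding_dim, padding_idx
    out = wte(torch.tensor([[5, 128009]]))
    assert out[0, 1].abs().sum() == 0  # the pad row embeds to all zeros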
@@ -932,20 +926,27 @@ class HunyuanImage3Model(nn.Module):
         sparse_interval = max(1, len(self.layers) // additional_layers)
         if len(self.layers[0].mlp.experts) == 0:
-            self.layers[0].mlp.experts = self.moe_loader.schedule_layer_load(0)
+            experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+            self.layers[0].mlp.experts = [expert._schedule_disk_load(0, i) for i, expert in enumerate(experts)]
         for layer_idx, decoder_layer in enumerate(self.layers):
-            if layer_idx + 1 < len(self.layers) and len(self.layers[layer_idx + 1].mlp.experts) == 0: # not loaded
-                self.layers[layer_idx+1].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 1)
+            # maybe the second layer loading should depend on how much gpu memory is there
+            next_layer = layer_idx + 1 if isinstance(self.layers[layer_idx + 1].mlp.experts, list) else layer_idx + 2
+            second_next_layer = next_layer + 1 if isinstance(self.layers[layer_idx + 2].mlp.experts, list) else next_layer + 2
-            if layer_idx + 2 < len(self.layers) and len(self.layers[layer_idx + 2].mlp.experts) == 0: # load first and second layers
-                self.layers[layer_idx+2].mlp.experts = self.moe_loader.schedule_layer_load(layer_idx + 2)
-            if not self.additional_layers_set:
-                if (layer_idx % sparse_interval == 0) and layer_idx >= sparse_interval:
-                    self.layers[next_layers].mlp.experts = self.moe_loader.schedule_layer_load(next_layers)
-                    next_layers += 1
+            if next_layer < len(self.layers) and len(self.layers[next_layer].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layer].mlp.experts = [expert._schedule_disk_load(next_layer, i) for i, expert in enumerate(experts)]
+            if second_next_layer < len(self.layers) and len(self.layers[second_next_layer].mlp.experts) == 0: # not loaded
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[second_next_layer].mlp.experts = [expert._schedule_disk_load(second_next_layer, i) for i, expert in enumerate(experts)]
+            if (layer_idx % sparse_interval == 0) and layer_idx > sparse_interval:
+                experts = [LazyMoELoader(self.moe_lru, self.config) for _ in range(64)]
+                self.layers[next_layers].mlp.experts = [expert._schedule_disk_load(next_layers, i) for i, expert in enumerate(experts)]
+                next_layers += 1
             with torch.no_grad():
                 layer_outputs = decoder_layer(
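The rewritten loop is a look-ahead prefetcher: while layer i computes, the experts for the next one or two not-yet-loaded layers are already being scheduled, and every `sparse_interval` layers one extra layer (tracked by the `next_layers` counter) is warmed up to build headroom. The core idea, reduced to a schematic with a plain list standing in for the per-expert `_schedule_disk_load` futures:

    def prefetch_plan(num_layers, lookahead=2):
        # While layer i computes, the next `lookahead` layers get scheduled.
        for i in range(num_layers):
            yield i, list(range(i + 1, min(i + 1 + lookahead, num_layers)))

    for layer, ahead in prefetch_plan(6):
        print(layer, ahead)
    # 0 [1, 2]
    # 1 [2, 3] ... each layer's experts start loading two layers early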
@@ -1020,7 +1021,8 @@ class HunyuanImage3ForCausalMM(nn.Module):
         self.first_step = True
         self.kv_cache = None
-        self.token_dims = ()
+        self.encode_tok = None
+        self.special_tok = None

     @staticmethod
     def get_pos_emb(custom_pos_emb, position_ids):
@@ -1043,24 +1045,36 @@ class HunyuanImage3ForCausalMM(nn.Module):
         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
         bsz, seq_len, n_embd = inputs_embeds.shape
         cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()
+        height, width = x.size(2) * 16, x.size(3) * 16
         gen_timestep_scatter_index = 4
+        def fn(string, func = self.encode_tok):
+            return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
+                .unsqueeze(0).expand(bsz, -1, -1)
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio
-                if self.first_step:
-                    token_height, token_width = x[:, -2:, 0].tolist()[0]
-                    self.token_dims = (int(token_height), int(token_width))
-                    x = x[:, :-2, :]
-                else:
-                    token_height, token_width = self.token_dims
+                joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio
         img_slices = []
         bsz, seq_len, n_embd = inputs_embeds.shape
+        cond_timestep = torch.zeros(x.size(0))
+        t_emb = self.time_embed(timestep)
+        if self.first_step:
+            x, token_height, token_width = self.patch_embed(x, t_emb)
+            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
+            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+        else:
+            x, token_height, token_width = self.patch_embed(x, t_emb)
+            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
+            inputs_embeds = torch.cat([timestep_emb, x], dim=1)
+        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
         for i in range(x.size(0)):
             gen_offset = seq_len + x.size(1)
             if cond_exists:
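The new `fn` helper maps a token string to its embedding row and broadcasts it over the batch: ordinary strings go through `encode_tok` (the tokenizer's `convert_tokens_to_ids`), while special markers are looked up in the `special_tok` dict. The same pattern with a toy vocabulary:

    import torch
    import torch.nn as nn

    vocab = {"<boi>": 0, "<eoi>": 1, "hello": 2}  # toy stand-in for the tokenizer
    wte = nn.Embedding(3, 8)
    bsz = 2

    def fn(string, func=vocab):
        token_id = func[string] if isinstance(func, dict) else func(string)
        # (1, n_embd) -> (1, 1, n_embd) -> broadcast to (bsz, 1, n_embd)
        return wte(torch.tensor([token_id])).unsqueeze(0).expand(bsz, -1, -1)

    seq = torch.cat([fn("<boi>"), fn("hello"), fn("<eoi>")], dim=1)
    print(seq.shape)  # torch.Size([2, 3, 8])

On the first step the generated image is wrapped the same way: `<boi>`, a size token, a ratio token, a `<timestep>` placeholder (overwritten in place with the timestep embedding at index 4, i.e. `gen_timestep_scatter_index`), the patch embeddings, then `<eoi>`.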
@@ -1085,27 +1099,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
             rope_img = [(img_s[0], (token_height, token_width))]
             rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]
-        cond_timestep = torch.zeros(inputs_embeds.size(0))
-        t_emb = self.time_embed(cond_timestep)
-        if self.first_step:
-            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-        else:
-            t_emb = self.time_embed(timestep)
-            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
-            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-            inputs_embeds = torch.cat([timestep_emb, x], dim=1)
-        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
-        #/////////////
-        # cond_vae_images
-        # cond_timestep_scatter_index
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+                joint_image[:, 3:4, :] = self.timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)
             inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
         else:
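The one-line change at the bottom of this hunk is the substantive fix: the conditioning image is a clean input, so its timestep slot (row 3 of `joint_image`) must embed t = 0, the zero-filled `cond_timestep`, rather than the current sampling timestep. Schematically:

    import torch

    bsz, n_embd = 2, 8
    timestep = torch.full((bsz,), 0.7)  # current sampling step: applies to the generated image only
    cond_timestep = torch.zeros(bsz)    # the conditioning image is already denoised: t = 0

    def timestep_emb(t):                # toy stand-in for self.timestep_emb
        return t[:, None].repeat(1, n_embd)

    joint_image = torch.zeros(bsz, 6, n_embd)
    joint_image[:, 3:4, :] = timestep_emb(cond_timestep.reshape(-1)).reshape(bsz, -1, n_embd)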

@@ -1342,6 +1342,11 @@ class HunyuanImage3(supported_models_base.BASE):
+        state_dict["text_encoders.wte"] = state_dict["model.model.wte"]
+        state_dict.pop("model.model.wte", None)
         model = model_base.HunyuanImage3(self, device = device)
+        temp_tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer()
+        model.encode_tok = temp_tokenizer.tokenizer.convert_tokens_to_ids
+        model.special_tok = temp_tokenizer.tokenizer.added_tokens_encoder
         return model

     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImage3Tokenizer, comfy.text_encoders.hunyuan_image.HunyuanImage3)
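`convert_tokens_to_ids` resolves any token string through the tokenizer's full vocabulary, while `added_tokens_encoder` is a plain `{token: id}` dict holding only the added/special tokens, so it supports direct dictionary lookups for markers like `<boi>`. With a Hugging Face-style tokenizer the two agree on those tokens (a sketch; `tok` is assumed here, not defined in this diff):

    boi_id = tok.convert_tokens_to_ids("<boi>")  # works for any token string
    special = tok.added_tokens_encoder           # dict of added/special tokens only
    assert special["<boi>"] == boi_id

Stashing the bound method and the dict on the model (`encode_tok`, `special_tok`) lets the diffusion forward pass embed special tokens without holding a reference to the whole tokenizer.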

@@ -22,34 +22,13 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
                 io.Int.Input("height", min = 1, default = 512),
                 io.Int.Input("width", min = 1, default = 512),
                 io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
-                io.Clip.Input("clip"),
-                io.Model.Input("model")
             ],
             outputs=[io.Latent.Output(display_name="latent")]
         )

     @classmethod
-    def execute(cls, height, width, batch_size, clip, model):
-        encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
-        special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
-        word_embed = clip.wte
-        patch_embed = model.patch_embed
-        t_embed = model.time_embed
+    def execute(cls, height, width, batch_size):
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
-        latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
-        def tk_fn(token):
-            return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1))
-        def fn(string, func = encode_fn):
-            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
-                .unsqueeze(0).expand(batch_size, -1, -1)
-        latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
-        latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )

 class HunyuanImage3Conditioning(io.ComfyNode):
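Net effect of the node change above: the `clip` and `model` inputs are gone because patch embedding and the `<boi>`/size/ratio/timestep wrapping moved into the model's forward pass (first file in this commit), so the node now only has to produce Gaussian noise at 1/16 of the pixel resolution with 32 latent channels:

    import torch

    height, width = 1024, 1024  # after get_target_size rounding (illustrative values)
    latent = torch.randn(1, 32, height // 16, width // 16)
    print(latent.shape)  # torch.Size([1, 32, 64, 64])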