From b4001bbd27219a6752f4cdffc5ebb0e1678aaa77 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Tue, 25 Nov 2025 00:25:27 +0200
Subject: [PATCH] input correction and improvements

Fix conditional-image handling in HunyuanImage 3 and rework the lazy MoE
expert loader around asyncio:

* throttle per-expert disk loads with an asyncio.Semaphore and register
  each loaded expert on the loader's event loop instead of fanning out
  raw executor futures
* invert the cond_exists test: a -100.0 pad row marks the absence of a
  conditioning image, not its presence
* offset the VAE/ViT joint slices by gen_offset and build their RoPE
  entries only when a conditioning image exists
* concatenate x into inputs_embeds only on the first step, and pass the
  128 head dim to build_batch_2d_rope directly
* in the conditioning node, compute batch_size and hidden_size before
  fn() first needs them, and define vae_mask with its final shape in
  both branches
---
 comfy/ldm/hunyuan_image_3/model.py  | 61 ++++++++++++++++++++---------
 comfy_extras/nodes_hunyuan_image.py |  9 +++--
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 9a36f2ad4..750928f03 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -587,7 +587,7 @@ class LazyMoELoader(nn.Module):
         self.expert_pool = self.build_meta_experts()
 
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._semaphore = threading.Semaphore(max_concurrent_loads)
+        self._semaphore = asyncio.Semaphore(max_concurrent_loads)
 
     def build_meta_experts(self):
         pool = {}
@@ -632,17 +632,36 @@ class LazyMoELoader(nn.Module):
                 getattr(model, name).data = tensor
         return model
 
-    def _load_single_expert(self, layer_idx, expert_idx):
-        with self._semaphore:
-            return self.lazy_init(layer_idx, expert_idx)
+    def _register_expert_sync(self, layer_idx, expert_idx, moe_cpu):
+        self.cache.add_cpu(moe_cpu, (layer_idx * 64) + expert_idx)
+        asyncio.run_coroutine_threadsafe(
+            self.cache._async_load_to_gpu((layer_idx * 64) + expert_idx, moe_cpu),
+            self.cache._loop
+        )
+
+    async def lazy_load_from_disk(self, layer_idx, expert_idx):
+        loop = asyncio.get_running_loop()
+        async with self._semaphore:
+            moe_cpu = await loop.run_in_executor(self._executor, self.lazy_init, layer_idx, expert_idx)
+            self._loop.call_soon_threadsafe(self._register_expert_sync, layer_idx, expert_idx, moe_cpu)
+            return moe_cpu
+
+    async def schedule_layer_load_progressive(self, layer_idx, num_experts = 64):
+        tasks = [asyncio.create_task(self.lazy_load_from_disk(layer_idx, i)) for i in range(num_experts)]
+        results = await asyncio.gather(*tasks, return_exceptions=False)
+        return results
 
     def schedule_layer_load(self, layer_idx, num_experts = 64):
-        futures = []
-        for i in range(num_experts):
-            coro = asyncio.get_event_loop().run_in_executor(self._executor, self._load_single_expert, layer_idx, i)
-            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
-            futures.append(fut)
-        return futures
+        fut = asyncio.run_coroutine_threadsafe(
+            self.schedule_layer_load_progressive(layer_idx, num_experts),
+            self._loop
+        )
+        return fut
+
+    async def schedule_layer_load_(self, layer_idx):
+        tasks = [self.lazy_load_from_disk(layer_idx, i) for i in range(64)]
+        experts = await asyncio.gather(*tasks)
+        return experts
 
 def enough_vram(required_bytes):
     free, total = torch.cuda.mem_get_info()
@@ -945,7 +964,6 @@ class HunyuanImage3Model(nn.Module):
             if self.additional_layers_set and layer_idx <= self.additional_layers_set:
                 pass
             else:
-                torch.cuda.synchronize()
                 asyncio.run_coroutine_threadsafe(
                     self.moe_lru._async_offload_to_cpu(layer_idx),
                     self.moe_lru._loop
                 )
@@ -1025,7 +1043,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
 
-        cond_exists = (joint_image[:, 0, :] == -100.0).any(dim=1).any()
+        cond_exists = (joint_image[:, 0, :] != -100.0).any(dim=1).any()
 
         gen_timestep_scatter_index = 4
 
@@ -1048,19 +1066,24 @@ class HunyuanImage3ForCausalMM(nn.Module):
             if cond_exists:
                 vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
                 vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1
+                vae_start += gen_offset
+                vae_end += gen_offset
 
-                vit_start = vae_end + 1 + gen_offset
+                vit_start = vae_end + 1
                 vit_end = joint_image.size(1) - 1 + gen_offset
 
                 joint_slices_i = [
                     slice(vae_start, vae_end),
                     slice(vit_start, vit_end),
                 ]
+            else:
+                joint_slices_i = []
 
             gen_slices_i = [slice(seq_len, gen_offset)]
             img_slices.append(gen_slices_i + joint_slices_i)
 
         img_s = img_slices[0]
-        rope_image_info = [[(img_s[0], (token_height, token_width)), (img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]
+        rope_img = [(img_s[0], (token_height, token_width))]
+        rope_image_info = [rope_img if len(joint_slices_i) == 0 else rope_img + [(img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]
 
         cond_timestep = torch.zeros(inputs_embeds.size(0))
         t_emb = self.time_embed(cond_timestep)
@@ -1072,7 +1095,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
         t_emb = self.time_embed(timestep)
         x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
         timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-        x = torch.cat([timestep_emb, x], dim=1)
+        inputs_embeds = torch.cat([timestep_emb, x], dim=1)
+
+        input_args = [inputs_embeds, x] if self.first_step else [inputs_embeds]
 
         #/////////////
         # cond_vae_images
@@ -1082,9 +1107,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
             with torch.no_grad():
                 joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
 
-            inputs_embeds = torch.cat([inputs_embeds, x, joint_image], dim = 1)
+            inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
         else:
-            inputs_embeds = torch.cat([inputs_embeds, x, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token
+            inputs_embeds = torch.cat([*input_args, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token
 
         attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
@@ -1097,7 +1122,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         cos, sin = build_batch_2d_rope(
             image_infos=rope_image_info,
             seq_len=inputs_embeds.shape[1],
-            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
+            n_elem=128, # head dim
             base=10000.0,
         )
         custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index d43ae4fce..1989e4b3d 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -79,13 +79,14 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         patch_embed = model.patch_embed
         t_embed = model.time_embed
 
+        text_tokens = text_encoding[0][0]
+        batch_size, _, hidden_size = text_tokens.shape
+
         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                 .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
 
-        text_tokens = text_encoding[0][0]
         text_tokens = torch.cat([fn("<|startoftext|>"), text_tokens], dim = 1)
-        batch_size, _, hidden_size = text_tokens.shape
 
         if vae_encoding is not None or vit_encoding is not None:
             vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0))))
@@ -93,11 +94,13 @@ class HunyuanImage3Conditioning(io.ComfyNode):
             joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn(""), fn("<|endoftext|>")], dim = 1)
 
             vae_mask = torch.ones(joint_image.size(1))
             vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
+            vae_mask = vae_mask.unsqueeze(0).unsqueeze(-1)
         else:
             pad_token = torch.tensor([-100.0]).view(1, 1, 1).expand(batch_size, 1, hidden_size)
             joint_image = torch.cat([pad_token, fn("<|endoftext|>")], dim = 1)
+            vae_mask = torch.empty_like(joint_image)
 
-        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)])
+        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask, text_tokens.to(joint_image.dtype)])
 
         uncond_ragged_tensors = None
         if text_encoding_negative is not None:
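--
Usage sketch (illustrative, not part of the commit): driving the reworked
loader from outside its event loop. `loader` stands for a LazyMoELoader
whose `_loop` runs in a background thread, as in model.py:

    # Thread-safe entry point: schedule_layer_load hands the whole-layer
    # coroutine to the loader's loop via asyncio.run_coroutine_threadsafe
    # and returns a concurrent.futures.Future.
    fut = loader.schedule_layer_load(layer_idx=0, num_experts=64)
    experts = fut.result()  # blocks until all 64 expert loads finish

Each per-expert load awaits the asyncio.Semaphore, so at most
max_concurrent_loads lazy_init disk reads are in flight at once; every
loaded expert is passed to _register_expert_sync, which adds it to the
CPU cache and queues its GPU upload on the cache's loop.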