From 10a17dc85d6d1701d9742d81e552ab654ad98723 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sat, 1 Nov 2025 16:40:49 +0200
Subject: [PATCH] a bunch of fixes

---
 comfy/ldm/hunyuan_image_3/model.py  | 20 ++++++----
 comfy_extras/nodes_hunyuan_image.py | 57 +++++++++--------------------
 2 files changed, 29 insertions(+), 48 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 949769839..ba2c1e90c 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -1053,6 +1053,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         gen_timestep_scatter_index = 4
         cond, uncond = condition[:4], condition[4:]
         joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1], cond[2]
+        joint_image[:, 2] = x[:, 2] # updates image ratio
 
         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
@@ -1079,11 +1080,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         if self.first_step:
             t_emb = self.time_embed(timestep)
-            x[:, 5:-4], token_h, token_w = self.patch_embed(x[:, 5:-4], t_emb)
+            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
             x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             t_emb = self.time_embed(timestep)
-            x[:, 5:-4], token_h, token_w = self.patch_embed(x, t_emb)
+            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
             timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
             x = torch.cat([timestep_emb, x], dim=1)
 
@@ -1095,16 +1096,19 @@ class HunyuanImage3ForCausalMM(nn.Module):
             # cond_timestep_scatter_index
             joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
             # conditioning images (vae)
-            joint_image[:, 7:cond_vae_image_mask.size(0)], token_h, token_w = self.patch_embed(
-                joint_image[:, 7:cond_vae_image_mask.size(0)], self.time_embed(cond_timestep)
+            joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed(
+                joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep)
             )
 
             inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)
 
-        batch_image_slices = [
-            input_ids[i] + x[i]
-            for i in range(bsz)
-        ]
+        batch_image_slices = []
+        for i in range(x.size(0)):
+            # slice the vae and vit parts + slice the latent from x
+            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
+            gen_slices_i = [slice(3, x[i].size(1) - 1)]
+            batch_image_slices.append(joint_slices_i + gen_slices_i)
+
         attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
             for _, image_slice in enumerate(batch_image_slices[i]):
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index ada042fc5..ba69c0978 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -35,8 +35,7 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device())
 
-        latent = torch.cat([fn(""), fn("_start"), fn("", special_fn), fn(f"", special_fn),
-                            latent, fn(""), fn("_start"), fn("_end"), fn("_end")], dim = 1)
+        latent = torch.cat([fn(""), fn("", special_fn), fn(f"", special_fn), latent, fn("")], dim = 1)
 
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, ) class HunyuanImage3Conditioning(io.ComfyNode): @@ -63,51 +62,29 @@ class HunyuanImage3Conditioning(io.ComfyNode): def fn(string, func = encode_fn): return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0) - text_encoding = text_encoding[0][0] + text_tokens = text_encoding[0][0] + # should dynamically change in model logic + joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1) - text_tokens = torch.cat([fn("_start"), text_encoding, fn("_end")], dim = 1) - vae_tokens = torch.cat([fn("_start"), fn("_start"), fn("_start"), vae_encoding, fn("_end"), fn("_end"), fn("")], dim = 1) - vit_tokens = torch.cat([fn("_start"), fn("_start"), vit_encoding, fn("_end"), fn("_end"), fn("_end")], dim = 1) - n, seq_len, dim = vit_tokens.shape - vit_tokens = vit_tokens.reshape(n * seq_len, dim) - # should dynamically change in model logic - joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_tokens, vit_tokens, fn("")], dim = 1) + vae_mask = torch.ones(joint_image.size(1)) + vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2) - seq_len_total = joint_image.shape[1] - mask = torch.zeros(seq_len_total, dtype=torch.bool, device=joint_image.device) - positions = {} - current = 4 - - def mark_region(name, tensor): - nonlocal current - start = current - current += tensor.shape[1] - end = current - 1 - positions[f"<{name}>_start"] = start - positions[f"<{name}>_end"] = end - mask[start:end + 1] = True - return start, end - - mark_region("vae_img", vae_tokens) - - mask_list = [] - for prefix in ["text", "vae_img", "vit_img"]: - start = positions[f"<{prefix}>_start"] - end = positions[f"<{prefix}>_end"] - - section_mask = torch.arange(start, end + 1, device=mask.device) - mask_list.append(section_mask) - - mask_list.insert(0, joint_image) - mask_list.append(text_tokens) - ragged_tensors = torch.nested.nested_tensor(mask_list, dtype=torch.long) + ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)]) + uncond_ragged_tensors = None if text_encoding_negative is not None: - uncond_ragged_tensors = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) + uncond_ragged_tensors, _ = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None) else: uncond_ragged_tensors = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged_tensors.unbind()]) - return ragged_tensors, uncond_ragged_tensors + if uncond_ragged_tensors is not None: + positive = [[ragged_tensors, {}]] + negative = [[uncond_ragged_tensors, {}]] + else: + positive = ragged_tensors + negative = uncond_ragged_tensors + + return positive, negative class Image3Extension(ComfyExtension): @override