diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 6ccfef697..dca38b20e 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -1045,51 +1045,59 @@ class HunyuanImage3ForCausalMM(nn.Module):
     def forward(self, x, condition, timestep, **kwargs):
-        joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
+        joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
         if self.kv_cache is None: # TODO: should change when higgsv2 gets merged
             self.kv_cache = HunyuanStaticCache(
                 config=self.config,
                 batch_size=x.size(0) * 2,
-                max_cache_len = input_ids.shape[1],
+                max_cache_len = inputs_embeds.shape[1],
                 dtype=x.dtype,
             )
 
-        image_mask = torch.ones(x.size(1))
+        image_mask = torch.ones(x.size(1), device=x.device)
         image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
         gen_timestep_scatter_index = 4
 
         with torch.no_grad():
-            joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio
+            joint_image[:, 2, :] = x[:, 2, :] # updates image ratio
 
-        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
+        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
         token_height = height // (16 * 16)
         token_width = width // (16 * 16)
 
-        rope_image_info = [[(None, (token_height, token_width))] * 2]
-        seq_len = input_ids.shape[1]
+        batch_image_slices = []
+        for i in range(x.size(0)):
+            # slice the vae and vit parts + slice the latent from x
+            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
+            gen_slices_i = [slice(3, x[i].size(1) - 1)]
+            batch_image_slices.append(joint_slices_i + gen_slices_i)
+
+        rope_image_info = [
+            [(s, (token_height, token_width)) for s in slices_i]
+            for slices_i in batch_image_slices
+        ]
+        seq_len = inputs_embeds.shape[1]
         cos, sin = build_batch_2d_rope(
             image_infos=rope_image_info,
             seq_len=seq_len,
             n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
             base=10000.0,
         )
-        custom_pos_emb = (sin, cos)
+        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
         custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
 
-        inputs_embeds = self.model.wte(input_ids)
         cond_timestep = torch.zeros(inputs_embeds.size(0))
         t_emb = self.time_embed(cond_timestep)
         bsz, seq_len, n_embd = inputs_embeds.shape
 
+        # FIXME: token_h and token_w for the first step
         if self.first_step:
-            t_emb = self.time_embed(timestep)
-            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
-            x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             t_emb = self.time_embed(timestep)
             x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
@@ -1103,20 +1111,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         # cond_timestep_scatter_index
         joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
 
-        # conditioning images (vae)
-        joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed(
-            joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep)
-        )
         inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)
 
-        batch_image_slices = []
-        for i in range(x.size(0)):
-            # slice the vae and vit parts + slice the latent from x
-            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
-            gen_slices_i = [slice(3, x[i].size(1) - 1)]
-            batch_image_slices.append(joint_slices_i + gen_slices_i)
-
         attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
             for _, image_slice in enumerate(batch_image_slices[i]):
@@ -1139,7 +1136,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         if past_key_value is not None:
             self.kv_cache = past_key_value
 
-        hidden_states = hidden_states.to(input_ids.device)
+        hidden_states = hidden_states.to(inputs_embeds.device)
         diffusion_prediction = self.ragged_final_layer(
             hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index 1e748669e..7da2e5718 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -30,16 +30,20 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
     def execute(cls, height, width, batch_size, clip):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
-        word_embed = clip.tokenizer.wte
-        hidden_size = word_embed.weight.shape[1]
+        # may convert clip.tokenizer -> clip.
+        word_embed = clip.tokenizer.wte
+        patch_embed = clip.tokenizer.patch_embed
+        t_embed = clip.tokenizer.time_embed
 
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
+
+        latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
 
         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
-                .view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16)
+                .unsqueeze(0).expand(batch_size, -1, -1)
 
         latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
@@ -67,13 +71,16 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
         word_embed = clip.tokenizer.wte
-        batch_size, _, hidden_size = vae_encoding.shape
+        patch_embed = clip.tokenizer.patch_embed
+        t_embed = clip.tokenizer.time_embed
+        batch_size, _, hidden_size = vit_encoding.shape
 
         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
-                .view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
+                .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
 
         text_tokens = text_encoding[0][0]
+        vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0)))) # should dynamically change in model logic
         joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1)
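
For reference, below is a minimal standalone sketch (not part of the patch) of the sequence layout that the new `batch_image_slices` / `rope_image_info` code builds before calling `build_batch_2d_rope`. The sizes `cond_vae_len`, `joint_len`, `gen_len`, and the 4x4 token grid are hypothetical placeholders for `cond_vae_image_mask[i].size(0)`, `joint_image.size(1)`, `x[i].size(1)`, and `(token_height, token_width)`.

```python
# Illustrative only: toy sizes standing in for the tensors used in
# HunyuanImage3ForCausalMM.forward (cond_vae_image_mask, joint_image, x).
batch = 2
cond_vae_len = 16                          # cond_vae_image_mask[i].size(0)
joint_len = 3 + cond_vae_len + 1 + 16 + 1  # prefix + vae tokens + separator + vit tokens + suffix
gen_len = 3 + 16 + 1                       # x[i].size(1): prefix + latent tokens + suffix
token_height, token_width = 4, 4           # 4 * 4 = 16 tokens per image span

batch_image_slices = []
for i in range(batch):
    # vae and vit spans inside the joint (conditioning) sequence ...
    joint_slices_i = [slice(3, cond_vae_len + 3), slice(cond_vae_len + 4, joint_len - 1)]
    # ... plus the generated-latent span taken from x
    gen_slices_i = [slice(3, gen_len - 1)]
    batch_image_slices.append(joint_slices_i + gen_slices_i)

# Every image span is tagged with the same 2-D token grid; this is the
# per-sample structure passed as image_infos to build_batch_2d_rope.
rope_image_info = [
    [(s, (token_height, token_width)) for s in slices_i]
    for slices_i in batch_image_slices
]
print(rope_image_info[0])
# [(slice(3, 19, None), (4, 4)), (slice(20, 36, None), (4, 4)), (slice(3, 19, None), (4, 4))]
```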