diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index dca38b20e..3cbca46cd 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -733,7 +733,7 @@ class HunyuanMoE(nn.Module):
                 expert = LazyMoELoader()
                 expert = expert.lazy_init(self.config, self.layer_idx, e)
                 self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                experts_list.append((e, expert))
+            experts_list.append((e, expert))

         per_pos, per_tokens, per_weights = [], [], []
         for e, _ in experts_list:
@@ -773,7 +773,8 @@ class HunyuanMoE(nn.Module):

         x = torch.bmm(tokens_padded, W1_T)
-        x = F.silu(x)
-        out_padded = torch.bmm(x, W2_T)
+        # SwiGLU-style gate: W1 projects to twice the FFN width, one half gates the other
+        x1, x2 = x.chunk(2, dim=2)
+        out_padded = torch.bmm(x1 * F.silu(x2), W2_T)

         out_padded = out_padded * weights_padded.unsqueeze(-1)
@@ -1025,6 +1026,7 @@ class HunyuanImage3ForCausalMM(nn.Module):

         self.first_step = True
         self.kv_cache = None
+        self.token_dims = ()

     @staticmethod
     def get_pos_emb(custom_pos_emb, position_ids):
@@ -1047,6 +1049,74 @@ class HunyuanImage3ForCausalMM(nn.Module):

         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

+        gen_timestep_scatter_index = 4
+
+        with torch.no_grad():
+            joint_image[:, 2:3, :] = x[:, 2:3, :]  # updates image ratio
+
+        if self.first_step:
+            # the last two rows of x carry the token-grid height and width; pop and cache them
+            token_height, token_width = x[:, -2:, 0].tolist()[0]
+            self.token_dims = (int(token_height), int(token_width))
+            x = x[:, :-2, :]
+        else:
+            token_height, token_width = self.token_dims
+
+        img_slices = []
+
+        for i in range(x.size(0)):
+            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
+            vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1
+
+            vit_start = vae_end + 1
+            vit_end = joint_image.size(1) - 1
+
+            joint_slices_i = [
+                slice(vae_start, vae_end),
+                slice(vit_start, vit_end),
+            ]
+            gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)]
+            img_slices.append(joint_slices_i + gen_slices_i)
+
+        img_s = img_slices[0]
+        rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]]
+
+        cond_timestep = torch.zeros(inputs_embeds.size(0))
+        t_emb = self.time_embed(cond_timestep)
+
+        bsz, seq_len, n_embd = inputs_embeds.shape
+
+        if self.first_step:
+            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index + 1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+        else:
+            t_emb = self.time_embed(timestep)
+            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
+            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
+            x = torch.cat([timestep_emb, x], dim=1)
+
+        # cond_vae_images: scatter the current timestep embedding at the cond timestep index
+        with torch.no_grad():
+            joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+
+        inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim=1)
+
+        attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
+        for i in range(bsz):
+            for image_slice in img_slices[i]:
+                attention_mask[i, image_slice, image_slice] = True
+        attention_mask = attention_mask.unsqueeze(1)
+
+        # pos embed
+        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
+        cos, sin = build_batch_2d_rope(
+            image_infos=rope_image_info,
+            seq_len=inputs_embeds.shape[1],
+            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
+            base=10000.0,
+        )
+        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
+        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
+
         if self.kv_cache is None:  # TODO: should change when higgsv2 gets merged
             self.kv_cache = HunyuanStaticCache(
@@ -1056,70 +1128,6 @@ class HunyuanImage3ForCausalMM(nn.Module):
                 dtype=x.dtype,
             )

-        image_mask = torch.ones(x.size(1), device=x.device)
-        image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
-        gen_timestep_scatter_index = 4
-
-        with torch.no_grad():
-            joint_image[:, 2, :] = x[:, 2, :]  # updates image ratio
-
-        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
-        height, width = x.shape[2] * 16, x.shape[3] * 16
-        token_height = height // (16 * 16)
-        token_width = width // (16 * 16)
-
-        batch_image_slices = []
-        for i in range(x.size(0)):
-            # slice the vae and vit parts + slice the latent from x
-            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
-            gen_slices_i = [slice(3, x[i].size(1) - 1)]
-            batch_image_slices.append(joint_slices_i + gen_slices_i)
-
-        rope_image_info = [
-            [(s, (token_height, token_width)) for s in slices_i]
-            for slices_i in batch_image_slices
-        ]
-        seq_len = inputs_embeds.shape[1]
-        cos, sin = build_batch_2d_rope(
-            image_infos=rope_image_info,
-            seq_len=seq_len,
-            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
-            base=10000.0,
-        )
-        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
-
-        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
-
-        cond_timestep = torch.zeros(inputs_embeds.size(0))
-        t_emb = self.time_embed(cond_timestep)
-
-        bsz, seq_len, n_embd = inputs_embeds.shape
-
-        # FIXME: token_h and token_w for the first step
-        if self.first_step:
-            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-        else:
-            t_emb = self.time_embed(timestep)
-            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
-            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-            x = torch.cat([timestep_emb, x], dim=1)
-
-        inputs_embeds = torch.cat([inputs_embeds, x], dim = 1)
-
-        #/////////////
-        # cond_vae_images
-
-        # cond_timestep_scatter_index
-        joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-
-        inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)
-
-        attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
-        for i in range(bsz):
-            for _, image_slice in enumerate(batch_image_slices[i]):
-                attention_mask[i, image_slice, image_slice] = True
-        attention_mask = attention_mask.unsqueeze(1)
-
         outputs = self.model(
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1137,8 +1145,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
             self.kv_cache = past_key_value

         hidden_states = hidden_states.to(inputs_embeds.device)
+        # mark the generated-image tokens: skip the 4 leading control tokens and the trailing one
+        img_mask = torch.zeros(hidden_states.size(1))
+        img_mask[-x.size(1) + 4:] = 1
+        img_mask[-1] = 0
+
         diffusion_prediction = self.ragged_final_layer(
-            hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
+            hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

         if self.first_step:
             self.first_step = False
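The `@@ -773` hunk above switches the batched expert MLP from a plain SiLU activation to a SwiGLU-style gated one: `W1_T` projects to twice the FFN width, the result is split in half along the feature dim, and one half gates the other through SiLU before the down-projection. A minimal standalone sketch of that computation (shapes and names here are illustrative, not the module's actual attributes):

```python
import torch
import torch.nn.functional as F

# Toy shapes: E experts, T padded tokens each, model width d_model, FFN width d_ff.
E, T, d_model, d_ff = 4, 8, 64, 128

tokens_padded = torch.randn(E, T, d_model)
W1_T = torch.randn(E, d_model, 2 * d_ff)  # up-projection, doubled to carry the gate
W2_T = torch.randn(E, d_ff, d_model)      # down-projection

x = torch.bmm(tokens_padded, W1_T)        # (E, T, 2 * d_ff)
x1, x2 = x.chunk(2, dim=2)                # value half and gate half
out = torch.bmm(x1 * F.silu(x2), W2_T)    # gated activation, then project back
print(out.shape)                          # torch.Size([4, 8, 64])
```

Assuming the checkpoint's up-projection really is `2 * d_ff` wide (which the `chunk(2, dim=2)` implies), the old single `F.silu(x)` + `bmm` path could not even match `W2_T`'s shape, which is presumably what this hunk fixes.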
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 4c8d53cac..510cb9da7 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -268,7 +268,12 @@ class ResBlock(TimestepBlock):
         if emb_out is not None:
             if self.exchange_temb_dims:
                 emb_out = emb_out.movedim(1, 2)
-            h = h + emb_out
+            try:
+                h = h + emb_out
+            except RuntimeError:
+                # broadcasting failed: the temb dims still need exchanging, so retry
+                emb_out = emb_out.movedim(1, 2)
+                h = h + emb_out
         h = self.out_layers(h)
         return self.skip_connection(x) + h
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index 7da2e5718..012ac8a08 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -39,13 +39,18 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
-        latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
+        latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
+
+        def tk_fn(token):
+            # broadcast a scalar token dimension into a (batch_size, 1, n_embd) row
+            return torch.tensor([token], device=latent.device, dtype=latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1))

         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                 .unsqueeze(0).expand(batch_size, -1, -1)

         latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1)
+        latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
@@ -87,7 +91,7 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         vae_mask = torch.ones(joint_image.size(1))
         vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))

-        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])
+        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)])

         uncond_ragged_tensors = None
         if text_encoding_negative is not None:
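The `EmptyLatentHunyuanImage3` change and the `first_step` branch in `model.py` form a pair: the node appends two extra rows that smuggle the token-grid height and width along with the latent sequence, and the model pops them off (caching them in `self.token_dims`) on its first forward pass. A toy round-trip of that side channel, with hypothetical shapes:

```python
import torch

batch_size, seq_len, n_embd = 2, 10, 16
latent = torch.randn(batch_size, seq_len, n_embd)
tk_height, tk_width = 64, 48  # hypothetical token-grid dimensions

def tk_fn(token):
    # one (batch, 1, n_embd) row whose every element is the scalar token value
    return torch.tensor([token], dtype=latent.dtype).unsqueeze(1).expand(batch_size, 1, n_embd)

# sender side (the node): append the two scalar rows after the latent tokens
x = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim=1)

# receiver side (the model's first step): read the scalars back and strip the rows
token_height, token_width = x[:, -2:, 0].tolist()[0]
x = x[:, :-2, :]

assert (int(token_height), int(token_width)) == (tk_height, tk_width)
assert x.shape == (batch_size, seq_len, n_embd)
```

Carrying the dims inside the tensor keeps the latent dict's interface unchanged, at the cost of the first-step special case in the forward pass.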