diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index 682fd8781..6ccfef697 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -337,7 +337,7 @@ class UNetDown(nn.Module):
 
         if self.patch_size == 1:
             self.model.append(ResBlock(
-                in_channels=hidden_channels,
+                channels=hidden_channels,
                 emb_channels=emb_channels,
                 out_channels=out_channels,
                 dropout=dropout,
@@ -346,7 +346,7 @@ class UNetDown(nn.Module):
         else:
             for i in range(self.patch_size // 2):
                 self.model.append(ResBlock(
-                    in_channels=hidden_channels,
+                    channels=hidden_channels,
                     emb_channels=emb_channels,
                     out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                     dropout=dropout,
@@ -381,7 +381,7 @@ class UNetUp(nn.Module):
 
         if self.patch_size == 1:
             self.model.append(ResBlock(
-                in_channels=in_channels,
+                channels=in_channels,
                 emb_channels=emb_channels,
                 out_channels=hidden_channels,
                 dropout=dropout,
@@ -390,7 +390,7 @@ class UNetUp(nn.Module):
         else:
             for i in range(self.patch_size // 2):
                 self.model.append(ResBlock(
-                    in_channels=in_channels if i == 0 else hidden_channels,
+                    channels=in_channels if i == 0 else hidden_channels,
                     emb_channels=emb_channels,
                     out_channels=hidden_channels,
                     dropout=dropout,
@@ -929,7 +929,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
 
 class HunyuanImage3Model(nn.Module):
     def __init__(self, config, moe_lru=None):
-        super().__init__(config)
+        super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
@@ -989,12 +989,12 @@ class HunyuanImage3Model(nn.Module):
 
 class HunyuanImage3ForCausalMM(nn.Module):
     def __init__(self, config):
-        super().__init__(config)
+        super().__init__()
         self.config = config
         self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
         self.patch_embed = UNetDown(
-            patch_size=16,
+            patch_size=1,
             emb_channels=config["hidden_size"],
             in_channels=32,
             hidden_channels=1024,
@@ -1003,7 +1003,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
         self.final_layer = UNetUp(
-            patch_size=16,
+            patch_size=1,
             emb_channels=config["hidden_size"],
             in_channels=config["hidden_size"],
             hidden_channels=1024,
@@ -1045,8 +1045,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
     def forward(self, x, condition, timestep, **kwargs):
-        cond, uncond = condition[:4], condition[4:]
-        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
+        joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
 
         if self.kv_cache is None:
             # TODO: should change when higgsv2 gets merged
@@ -1058,9 +1057,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
             )
 
         image_mask = torch.ones(x.size(1))
-        image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
+        image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
         gen_timestep_scatter_index = 4
-        joint_image[:, 2] = x[:, 2] # updates image ratio
+
+        with torch.no_grad():
+            joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio
 
         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 4669eb14b..816aed169 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -490,6 +490,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["max_position_embeddings"] = 12800
         dit_config["num_attention_heads"] = 32
         dit_config['rms_norm_eps'] = 1e-05
+        dit_config["num_hidden_layers"] = 32
         return dit_config
 
     if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
diff --git a/comfy_extras/nodes_hunyuan_image.py b/comfy_extras/nodes_hunyuan_image.py
index ba69c0978..1e748669e 100644
--- a/comfy_extras/nodes_hunyuan_image.py
+++ b/comfy_extras/nodes_hunyuan_image.py
@@ -30,12 +30,18 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
     def execute(cls, height, width, batch_size, clip):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
-        def fn(string, func = encode_fn):
-            return torch.tensor(func(string), device=comfy.model_management.intermediate_device()).unsqueeze(0)
+        word_embed = clip.tokenizer.wte
+
+        hidden_size = word_embed.weight.shape[1]
 
         height, width = get_target_size(height, width)
-        latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device())
-        latent = torch.cat([fn(""), fn("", special_fn), fn(f"", special_fn), latent, fn("")], dim = 1)
+        latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
+
+        def fn(string, func = encode_fn):
+            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
+                .view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16)
+
+        latent = torch.cat([fn(""), fn("", func = special_fn), fn(f"", special_fn), fn("", special_fn), latent, fn("")], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
 
 class HunyuanImage3Conditioning(io.ComfyNode):
@@ -59,15 +65,20 @@ class HunyuanImage3Conditioning(io.ComfyNode):
     def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
+
+        word_embed = clip.tokenizer.wte
+        batch_size, _, hidden_size = vae_encoding.shape
+
         def fn(string, func = encode_fn):
-            return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0)
+            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
+                .view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
 
         text_tokens = text_encoding[0][0] # should dynamically change in model logic
         joint_image = torch.cat([fn(""), fn("", special_fn), fn("", special_fn), fn("", special_fn), vae_encoding, fn(""), vit_encoding, fn("")], dim = 1)
 
         vae_mask = torch.ones(joint_image.size(1))
-        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2)
+        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
 
         ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])