fixes from testing

This commit is contained in:
Yousef Rafat 2025-11-04 23:55:16 +02:00
parent ca119c44fb
commit 9e9c536c8e
3 changed files with 31 additions and 18 deletions

View File

@ -337,7 +337,7 @@ class UNetDown(nn.Module):
if self.patch_size == 1:
self.model.append(ResBlock(
in_channels=hidden_channels,
channels=hidden_channels,
emb_channels=emb_channels,
out_channels=out_channels,
dropout=dropout,
@ -346,7 +346,7 @@ class UNetDown(nn.Module):
else:
for i in range(self.patch_size // 2):
self.model.append(ResBlock(
in_channels=hidden_channels,
channels=hidden_channels,
emb_channels=emb_channels,
out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
dropout=dropout,
@ -381,7 +381,7 @@ class UNetUp(nn.Module):
if self.patch_size == 1:
self.model.append(ResBlock(
in_channels=in_channels,
channels=in_channels,
emb_channels=emb_channels,
out_channels=hidden_channels,
dropout=dropout,
@ -390,7 +390,7 @@ class UNetUp(nn.Module):
else:
for i in range(self.patch_size // 2):
self.model.append(ResBlock(
in_channels=in_channels if i == 0 else hidden_channels,
channels=in_channels if i == 0 else hidden_channels,
emb_channels=emb_channels,
out_channels=hidden_channels,
dropout=dropout,
@ -929,7 +929,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
class HunyuanImage3Model(nn.Module):
def __init__(self, config, moe_lru=None):
super().__init__(config)
super().__init__()
self.padding_idx = 128009
self.vocab_size = 133120
self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
@ -989,12 +989,12 @@ class HunyuanImage3Model(nn.Module):
class HunyuanImage3ForCausalMM(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
self.config = config
self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
self.patch_embed = UNetDown(
patch_size=16,
patch_size=1,
emb_channels=config["hidden_size"],
in_channels=32,
hidden_channels=1024,
@ -1003,7 +1003,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
self.final_layer = UNetUp(
patch_size=16,
patch_size=1,
emb_channels=config["hidden_size"],
in_channels=config["hidden_size"],
hidden_channels=1024,
@ -1045,8 +1045,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
def forward(self, x, condition, timestep, **kwargs):
cond, uncond = condition[:4], condition[4:]
joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
if self.kv_cache is None:
# TODO: should change when higgsv2 gets merged
@ -1058,9 +1057,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
)
image_mask = torch.ones(x.size(1))
image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
gen_timestep_scatter_index = 4
joint_image[:, 2] = x[:, 2] # updates image ratio
with torch.no_grad():
joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
height, width = x.shape[2] * 16, x.shape[3] * 16

View File

@ -490,6 +490,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_position_embeddings"] = 12800
dit_config["num_attention_heads"] = 32
dit_config['rms_norm_eps'] = 1e-05
dit_config["num_hidden_layers"] = 32
return dit_config
if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2

View File

@ -30,12 +30,18 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
def execute(cls, height, width, batch_size, clip):
encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
def fn(string, func = encode_fn):
return torch.tensor(func(string), device=comfy.model_management.intermediate_device()).unsqueeze(0)
word_embed = clip.tokenizer.wte
hidden_size = word_embed.weight.shape[1]
height, width = get_target_size(height, width)
latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device())
latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn(f"<img_ratio_{height / width}", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
def fn(string, func = encode_fn):
return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
.view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16)
latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
class HunyuanImage3Conditioning(io.ComfyNode):
@ -59,15 +65,20 @@ class HunyuanImage3Conditioning(io.ComfyNode):
def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None):
encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
word_embed = clip.tokenizer.wte
batch_size, _, hidden_size = vae_encoding.shape
def fn(string, func = encode_fn):
return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0)
return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
.view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
text_tokens = text_encoding[0][0]
# should dynamically change in model logic
joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>")], dim = 1)
vae_mask = torch.ones(joint_image.size(1))
vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2)
vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])