From 88c350bfed1a959585cfc3f08f036e1af906bdd7 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 28 Nov 2025 20:43:01 +0200
Subject: [PATCH] corrected img_ratio

---
 comfy/ldm/hunyuan_image_3/model.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/hunyuan_image_3/model.py b/comfy/ldm/hunyuan_image_3/model.py
index f4cf935ef..099732a6c 100644
--- a/comfy/ldm/hunyuan_image_3/model.py
+++ b/comfy/ldm/hunyuan_image_3/model.py
@@ -15,6 +15,7 @@
 from transformers.cache_utils import StaticCache
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 INIT_MOE = torch.cuda.device_count() != 1
@@ -1048,10 +1049,14 @@ class HunyuanImage3ForCausalMM(nn.Module):
         def fn(string, func = self.encode_tok):
             return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                 .unsqueeze(0).expand(bsz, -1, -1)
+
+        hw = f"{int(height)}x{int(width)}"
+        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
+        img_ratio = fn(f"<img_ratio_{ratio_idx}>", self.special_tok)
 
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = fn(f"", self.special_tok) # updates image ratio
+                joint_image[:, 2:3, :] = img_ratio
 
         img_slices = []
 
@@ -1060,7 +1065,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         if self.first_step:
             x, token_height, token_width = self.patch_embed(x, t_emb)
 
-            x = torch.cat([fn(""), fn("", func = self.special_tok), fn(f"", self.special_tok), fn("", self.special_tok), x, fn("")], dim = 1)
+            x = torch.cat([fn(""), fn("", func = self.special_tok), img_ratio, fn("", self.special_tok), x, fn("")], dim = 1)
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             x, token_height, token_width = self.patch_embed(x, t_emb)
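
Note on the ratio lookup introduced above: the patch resolves the target height/width to an
index into COMPUTED_RESO_GROUPS (imported from comfy_extras.nodes_hunyuan_image) and embeds
the matching image-ratio special token once, so both places that previously re-derived the
ratio token now reuse the same img_ratio embedding. The sketch below is a minimal, standalone
illustration of that lookup only; the actual contents of COMPUTED_RESO_GROUPS and the exact
"<img_ratio_{n}>" token string are assumptions here, since the angle-bracket token text is not
visible in this rendering of the patch.

    # Hypothetical stand-in for comfy_extras.nodes_hunyuan_image.COMPUTED_RESO_GROUPS:
    # assumed to be a list of supported "HxW" resolution strings, one per aspect-ratio bucket.
    COMPUTED_RESO_GROUPS = ["1024x1024", "768x1280", "1280x768"]

    def ratio_token(height: float, width: float) -> str:
        # Build the "HxW" key the same way the patch does.
        hw = f"{int(height)}x{int(width)}"
        # Index of the matching resolution bucket; like the patch, this raises
        # IndexError if the resolution is not one of the precomputed groups.
        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
        # Assumed token format; the real string comes from the HunyuanImage-3
        # special-token vocabulary and is not shown in this rendering of the patch.
        return f"<img_ratio_{ratio_idx}>"

    print(ratio_token(768, 1280))  # -> <img_ratio_1> under the assumptions above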