corrected img_ratio

Yousef Rafat 2025-11-28 20:43:01 +02:00
parent bd2c2f7375
commit 88c350bfed


@@ -15,6 +15,7 @@ from transformers.cache_utils import StaticCache
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 INIT_MOE = torch.cuda.device_count() != 1
@@ -1048,10 +1049,14 @@ class HunyuanImage3ForCausalMM(nn.Module):
         def fn(string, func = self.encode_tok):
             return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                 .unsqueeze(0).expand(bsz, -1, -1)
+        hw = f"{int(height)}x{int(width)}"
+        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
+        img_ratio = fn(f"<img_ratio_{ratio_idx}>", self.special_tok)
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio
+                joint_image[:, 2:3, :] = img_ratio
                 img_slices = []
@@ -1060,7 +1065,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         if self.first_step:
             x, token_height, token_width = self.patch_embed(x, t_emb)
-            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
+            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), img_ratio, fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             x, token_height, token_width = self.patch_embed(x, t_emb)
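
For reference, a minimal standalone sketch of the corrected ratio lookup: the old code derived the <img_ratio_N> index from int(height) // int(width), which collapses most aspect ratios onto 0 or 1, while the new code looks the "{height}x{width}" string up in COMPUTED_RESO_GROUPS. The list contents and the ratio_index helper below are illustrative placeholders, not the actual table imported from comfy_extras.nodes_hunyuan_image.

# Illustrative sketch only; COMPUTED_RESO_GROUPS here is a placeholder list of
# "heightxwidth" strings, not the real table from comfy_extras.nodes_hunyuan_image.
COMPUTED_RESO_GROUPS = ["1024x1024", "768x1280", "1280x768"]

def ratio_index(height: float, width: float) -> int:
    # Same lookup as the patched code: find the resolution group whose "HxW"
    # string matches the request and use its position as the token index.
    hw = f"{int(height)}x{int(width)}"
    matches = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw]
    if not matches:
        raise ValueError(f"resolution {hw} not found in COMPUTED_RESO_GROUPS")
    return matches[0]

# Old behaviour: int(1024) // int(1024) == 1 and int(1280) // int(768) == 1,
# so square and 1280x768 images shared the same <img_ratio_1> token.
# New behaviour with the placeholder table: 1024x1024 -> <img_ratio_0>,
# 768x1280 -> <img_ratio_1>, 1280x768 -> <img_ratio_2>.
print(ratio_index(1024, 1024))  # 0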