corrected img_ratio

Yousef Rafat 2025-11-28 20:43:01 +02:00
parent bd2c2f7375
commit 88c350bfed


@@ -15,6 +15,7 @@ from transformers.cache_utils import StaticCache
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
@@ -1049,9 +1050,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
             return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                 .unsqueeze(0).expand(bsz, -1, -1)
 
+        hw = f"{int(height)}x{int(width)}"
+        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
+        img_ratio = fn(f"<img_ratio_{ratio_idx}>", self.special_tok)
+
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio
+                joint_image[:, 2:3, :] = img_ratio
 
         img_slices = []
@@ -1060,7 +1065,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         if self.first_step:
             x, token_height, token_width = self.patch_embed(x, t_emb)
-            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
+            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), img_ratio, fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             x, token_height, token_width = self.patch_embed(x, t_emb)
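
For context, a minimal standalone sketch of the corrected lookup, assuming COMPUTED_RESO_GROUPS is a list of "heightxwidth" resolution strings as used in the diff above. The bucket values below are illustrative placeholders, not the real table from comfy_extras.nodes_hunyuan_image.

    # Hypothetical subset of resolution buckets; the real list is imported as
    # COMPUTED_RESO_GROUPS from comfy_extras.nodes_hunyuan_image.
    COMPUTED_RESO_GROUPS = ["1024x1024", "768x1280", "1280x768"]

    def ratio_index(height, width):
        # Match the exact "HxW" bucket instead of the old int(height) // int(width),
        # which collapsed most square and landscape sizes onto the same index.
        hw = f"{int(height)}x{int(width)}"
        return next(i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw)

    # e.g. ratio_index(768, 1280) -> 1, so the model receives the "<img_ratio_1>" token.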