mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-16 01:00:49 +08:00)

corrected img_ratio

commit 88c350bfed, parent bd2c2f7375
@@ -15,6 +15,7 @@ from transformers.cache_utils import StaticCache
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock

 INIT_MOE = torch.cuda.device_count() != 1
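For context on the new import: the lookup added below assumes COMPUTED_RESO_GROUPS is a sequence of "HxW" resolution strings, one entry per aspect-ratio bucket, whose position doubles as the <img_ratio_N> token index. The real values live in comfy_extras.nodes_hunyuan_image; the list in this sketch is a hypothetical stand-in for illustration only.

    # Hypothetical stand-in for comfy_extras.nodes_hunyuan_image.COMPUTED_RESO_GROUPS;
    # the actual entries come from that module.
    COMPUTED_RESO_GROUPS = [
        "1024x1024",  # index 0 -> <img_ratio_0>
        "768x1024",   # index 1 -> <img_ratio_1>
        "1024x768",   # index 2 -> <img_ratio_2>
    ]

    height, width = 1024.0, 768.0
    hw = f"{int(height)}x{int(width)}"  # "1024x768"
    ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
    print(f"<img_ratio_{ratio_idx}>")   # <img_ratio_2>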
@@ -1048,10 +1049,14 @@ class HunyuanImage3ForCausalMM(nn.Module):
         def fn(string, func = self.encode_tok):
             return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                 .unsqueeze(0).expand(bsz, -1, -1)

+        hw = f"{int(height)}x{int(width)}"
+        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
+        img_ratio = fn(f"<img_ratio_{ratio_idx}>", self.special_tok)
+
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio
+                joint_image[:, 2:3, :] = img_ratio # updates image ratio

         img_slices = []
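The substance of the fix: the removed code derived the ratio token index by integer division, int(height) // int(width), which floors almost every resolution to 0 or 1 and lands on the intended bucket only by coincidence; the replacement looks the exact "HxW" string up in COMPUTED_RESO_GROUPS. A self-contained sketch of the difference, reusing the hypothetical groups above:

    def old_ratio_token(height, width):
        # Buggy: 1024x768 and 2048x1024 both floor-divide to 1, so distinct
        # aspect-ratio buckets collapse onto the same token.
        return f"<img_ratio_{int(height) // int(width)}>"

    def new_ratio_token(height, width, reso_groups):
        # Corrected: the token index is the position of the exact "HxW"
        # entry in the resolution table.
        hw = f"{int(height)}x{int(width)}"
        return f"<img_ratio_{[i for i, r in enumerate(reso_groups) if r == hw][0]}>"

    groups = ["1024x1024", "768x1024", "1024x768"]  # hypothetical, as above
    print(old_ratio_token(1024, 768))           # <img_ratio_1> -- wrong bucket
    print(new_ratio_token(1024, 768, groups))   # <img_ratio_2>

Hoisting the computation into img_ratio also means an unsupported resolution now fails fast with an IndexError from the [0] lookup, instead of silently writing a wrong ratio embedding into joint_image[:, 2:3, :].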
@@ -1060,7 +1065,7 @@ class HunyuanImage3ForCausalMM(nn.Module):

         if self.first_step:
             x, token_height, token_width = self.patch_embed(x, t_emb)
-            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
+            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), img_ratio, fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             x, token_height, token_width = self.patch_embed(x, t_emb)
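For readers tracing fn: it turns a token string into its embedding row via the model's token-embedding table and broadcasts it across the batch, which is why img_ratio can be written straight into joint_image[:, 2:3, :] and concatenated into x. A minimal sketch, with nn.Embedding standing in for self.model.wte and a toy dict standing in for self.special_tok:

    import torch
    import torch.nn as nn

    bsz, n_embd = 2, 16
    wte = nn.Embedding(32, n_embd)  # stand-in for self.model.wte
    special_tok = {"<boi>": 29, "<img_ratio_2>": 30, "<eoi>": 31}  # toy id table

    def fn(string, func=special_tok):
        # dict -> direct special-token id lookup; otherwise func is treated
        # as a tokenizer encode function returning a list of ids.
        ids = torch.tensor(func[string] if isinstance(func, dict) else func(string))
        # scalar id: (n_embd,) -> (1, n_embd) -> expand appends the batch
        # dim at the front -> (bsz, 1, n_embd)
        return wte(ids).unsqueeze(0).expand(bsz, -1, -1)

    img_ratio = fn("<img_ratio_2>")
    print(img_ratio.shape)  # torch.Size([2, 1, 16])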