Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-27 14:50:20 +08:00)

commit 88c350bfed (parent bd2c2f7375)

    corrected img_ratio
@@ -15,6 +15,7 @@ from transformers.cache_utils import StaticCache
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Tuple, Any, List, Dict
 from comfy.ldm.modules.attention import optimized_attention
+from comfy_extras.nodes_hunyuan_image import COMPUTED_RESO_GROUPS
 from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock
 
 INIT_MOE = torch.cuda.device_count() != 1
@@ -1049,9 +1050,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
             return self.model.wte(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=inputs_embeds.device))\
                 .unsqueeze(0).expand(bsz, -1, -1)
 
+        hw = f"{int(height)}x{int(width)}"
+        ratio_idx = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw][0]
+        img_ratio = fn(f"<img_ratio_{ratio_idx}>", self.special_tok)
+
         if cond_exists:
             with torch.no_grad():
-                joint_image[:, 2:3, :] = fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok) # updates image ratio
+                joint_image[:, 2:3, :] = img_ratio
 
         img_slices = []
 
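Note on the fix (not part of the diff): a minimal sketch of the corrected lookup, assuming COMPUTED_RESO_GROUPS is a list of "HxW" resolution strings; the toy entries below are placeholders, not the real table from comfy_extras.nodes_hunyuan_image, and ratio_token_index is a hypothetical helper name used only for illustration. The old code derived the ratio token from int(height) // int(width), which collapses most aspect ratios onto 0 or 1; the new code matches the exact "HxW" string and uses its position in the table as the token index.

    # Sketch only: these COMPUTED_RESO_GROUPS entries are made-up examples.
    COMPUTED_RESO_GROUPS = ["1024x1024", "768x1280", "1280x768", "864x1152"]

    def ratio_token_index(height: int, width: int) -> int:
        """Return the <img_ratio_N> index for an exact "HxW" match."""
        hw = f"{int(height)}x{int(width)}"
        matches = [i for i, reso in enumerate(COMPUTED_RESO_GROUPS) if reso == hw]
        if not matches:
            raise ValueError(f"resolution {hw} not found in COMPUTED_RESO_GROUPS")
        return matches[0]

    # Old behaviour for comparison: int(768) // int(1280) == 0 and
    # int(1280) // int(768) == 1, so distinct ratios shared the same token id.
    print(ratio_token_index(768, 1280))  # -> 1 with the toy table above

The patch inlines this lookup once per forward call rather than adding a helper, and reuses the resulting img_ratio tensor in both places that previously recomputed the token.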
@@ -1060,7 +1065,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         if self.first_step:
             x, token_height, token_width = self.patch_embed(x, t_emb)
-            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), fn(f"<img_ratio_{int(height) // int(width)}>", self.special_tok), fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
+            x = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = self.special_tok), img_ratio, fn("<timestep>", self.special_tok), x, fn("<eoi>")], dim = 1)
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             x, token_height, token_width = self.patch_embed(x, t_emb)
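For the generation branch, a placeholder sketch (toy tensors, not the model's real fn() embeddings) of why the same img_ratio tensor also satisfies the joint_image[:, 2:3, :] write above: in the concatenated sequence [<boi>, <img_size_1024>, <img_ratio>, <timestep>, image tokens, <eoi>], the ratio token sits at sequence position 2.

    import torch

    # Placeholder embeddings (batch=1, n_embd=4); shapes mirror the diff,
    # the values are arbitrary stand-ins for fn(...) outputs.
    n_embd = 4
    boi, img_size, img_ratio, timestep_tok, eoi = (
        torch.full((1, 1, n_embd), float(v)) for v in range(5)
    )
    img_tokens = torch.zeros(1, 3, n_embd)  # stands in for the patch-embedded latents

    x = torch.cat([boi, img_size, img_ratio, timestep_tok, img_tokens, eoi], dim=1)

    # Sequence position 2 is the ratio slot, matching joint_image[:, 2:3, :] = img_ratio.
    assert torch.equal(x[:, 2:3, :], img_ratio)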