mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
important fixes
This commit is contained in:
parent
9e9c536c8e
commit
5056a1f4d4
@ -1045,51 +1045,59 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
|
||||
def forward(self, x, condition, timestep, **kwargs):
|
||||
|
||||
joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
|
||||
joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
|
||||
|
||||
if self.kv_cache is None:
|
||||
# TODO: should change when higgsv2 gets merged
|
||||
self.kv_cache = HunyuanStaticCache(
|
||||
config=self.config,
|
||||
batch_size=x.size(0) * 2,
|
||||
max_cache_len = input_ids.shape[1],
|
||||
max_cache_len = inputs_embeds.shape[1],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
image_mask = torch.ones(x.size(1))
|
||||
image_mask = torch.ones(x.size(1), device=x.device)
|
||||
image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
|
||||
gen_timestep_scatter_index = 4
|
||||
|
||||
with torch.no_grad():
|
||||
joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio
|
||||
joint_image[:, 2, :] = x[:, 2, :] # updates image ratio
|
||||
|
||||
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
|
||||
position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
|
||||
height, width = x.shape[2] * 16, x.shape[3] * 16
|
||||
token_height = height // (16 * 16)
|
||||
token_width = width // (16 * 16)
|
||||
|
||||
rope_image_info = [[(None, (token_height, token_width))] * 2]
|
||||
seq_len = input_ids.shape[1]
|
||||
batch_image_slices = []
|
||||
for i in range(x.size(0)):
|
||||
# slice the vae and vit parts + slice the latent from x
|
||||
joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
|
||||
gen_slices_i = [slice(3, x[i].size(1) - 1)]
|
||||
batch_image_slices.append(joint_slices_i + gen_slices_i)
|
||||
|
||||
rope_image_info = [
|
||||
[(s, (token_height, token_width)) for s in slices_i]
|
||||
for slices_i in batch_image_slices
|
||||
]
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
cos, sin = build_batch_2d_rope(
|
||||
image_infos=rope_image_info,
|
||||
seq_len=seq_len,
|
||||
n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
|
||||
base=10000.0,
|
||||
)
|
||||
custom_pos_emb = (sin, cos)
|
||||
custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
|
||||
|
||||
custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
|
||||
inputs_embeds = self.model.wte(input_ids)
|
||||
|
||||
cond_timestep = torch.zeros(inputs_embeds.size(0))
|
||||
t_emb = self.time_embed(cond_timestep)
|
||||
|
||||
bsz, seq_len, n_embd = inputs_embeds.shape
|
||||
|
||||
# FIXME: token_h and token_w for the first step
|
||||
if self.first_step:
|
||||
t_emb = self.time_embed(timestep)
|
||||
x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
|
||||
x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
|
||||
x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
|
||||
else:
|
||||
t_emb = self.time_embed(timestep)
|
||||
x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
|
||||
@ -1103,20 +1111,9 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
|
||||
# cond_timestep_scatter_index
|
||||
joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
|
||||
# conditioning images (vae)
|
||||
joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed(
|
||||
joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep)
|
||||
)
|
||||
|
||||
inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)
|
||||
|
||||
batch_image_slices = []
|
||||
for i in range(x.size(0)):
|
||||
# slice the vae and vit parts + slice the latent from x
|
||||
joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
|
||||
gen_slices_i = [slice(3, x[i].size(1) - 1)]
|
||||
batch_image_slices.append(joint_slices_i + gen_slices_i)
|
||||
|
||||
attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
|
||||
for i in range(bsz):
|
||||
for _, image_slice in enumerate(batch_image_slices[i]):
|
||||
@ -1139,7 +1136,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
|
||||
if past_key_value is not None:
|
||||
self.kv_cache = past_key_value
|
||||
|
||||
hidden_states = hidden_states.to(input_ids.device)
|
||||
hidden_states = hidden_states.to(inputs_embeds.device)
|
||||
diffusion_prediction = self.ragged_final_layer(
|
||||
hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
|
||||
|
||||
|
||||
@ -30,16 +30,20 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
|
||||
def execute(cls, height, width, batch_size, clip):
|
||||
encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
|
||||
special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
|
||||
word_embed = clip.tokenizer.wte
|
||||
|
||||
hidden_size = word_embed.weight.shape[1]
|
||||
# may convert clip.tokenizer -> clip.
|
||||
word_embed = clip.tokenizer.wte
|
||||
patch_embed = clip.tokenizer.patch_embed
|
||||
t_embed = clip.tokenizer.time_embed
|
||||
|
||||
height, width = get_target_size(height, width)
|
||||
latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
|
||||
|
||||
latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
|
||||
|
||||
def fn(string, func = encode_fn):
|
||||
return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
|
||||
.view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16)
|
||||
.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
|
||||
return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
|
||||
@ -67,13 +71,16 @@ class HunyuanImage3Conditioning(io.ComfyNode):
|
||||
special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
|
||||
|
||||
word_embed = clip.tokenizer.wte
|
||||
batch_size, _, hidden_size = vae_encoding.shape
|
||||
patch_embed = clip.tokenizer.patch_embed
|
||||
t_embed = clip.tokenizer.time_embed
|
||||
batch_size, _, hidden_size = vit_encoding.shape
|
||||
|
||||
def fn(string, func = encode_fn):
|
||||
return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
|
||||
.view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
|
||||
.view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
|
||||
|
||||
text_tokens = text_encoding[0][0]
|
||||
vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0))))
|
||||
# should dynamically change in model logic
|
||||
joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>")], dim = 1)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user