Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-29 07:40:21 +08:00

Commit 10a17dc85d ("a bunch of fixes"), parent 1a25a0ad69
@@ -1053,6 +1053,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         gen_timestep_scatter_index = 4
         cond, uncond = condition[:4], condition[4:]
         joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
+        joint_image[:, 2] = x[:, 2] # updates image ratio

         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
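The position_ids line above broadcasts a single arange over the whole batch. A minimal standalone sketch of that pattern, with toy shapes rather than the model's real inputs:

import torch

# Toy illustration of the position_ids construction: one arange over the
# sequence length, expanded (a broadcast view, no copy) to every batch item.
input_ids = torch.zeros(2, 5, dtype=torch.long)
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long)[None].expand(input_ids.size(0), -1)
print(position_ids)  # tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])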
@@ -1079,11 +1080,11 @@ class HunyuanImage3ForCausalMM(nn.Module):

         if self.first_step:
             t_emb = self.time_embed(timestep)
-            x[:, 5:-4], token_h, token_w = self.patch_embed(x[:, 5:-4], t_emb)
+            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
             x[:, gen_timestep_scatter_index] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         else:
             t_emb = self.time_embed(timestep)
-            x[:, 5:-4], token_h, token_w = self.patch_embed(x, t_emb)
+            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
             timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
             x = torch.cat([timestep_emb, x], dim=1)

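The patch_embed slice moves from 5:-4 to 3:-1, which suggests the generated sequence now carries three leading and one trailing special embedding around the image patches (inferred from the slice itself, not stated in the commit). A minimal sketch of that in-place slice assignment:

import torch

# Toy illustration of writing patch embeddings back in place with x[:, 3:-1];
# the 3 leading / 1 trailing framing is an inference from the slice, not from the commit text.
bsz, seq, n_embd = 1, 8, 4
x = torch.zeros(bsz, seq, n_embd)
patches = torch.randn(bsz, seq - 4, n_embd)  # stand-in for patch_embed output
x[:, 3:-1] = patches                         # positions 0-2 and the last position stay untouched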
@@ -1095,16 +1096,19 @@ class HunyuanImage3ForCausalMM(nn.Module):
         # cond_timestep_scatter_index
         joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
         # conditioning images (vae)
-        joint_image[:, 7:cond_vae_image_mask.size(0)], token_h, token_w = self.patch_embed(
-            joint_image[:, 7:cond_vae_image_mask.size(0)], self.time_embed(cond_timestep)
+        joint_image[:, 3:cond_vae_image_mask.size(0)+3], token_h, token_w = self.patch_embed(
+            joint_image[:, 3:cond_vae_image_mask.size(0)+3], self.time_embed(cond_timestep)
         )

         inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)

-        batch_image_slices = [
-            input_ids[i] + x[i]
-            for i in range(bsz)
-        ]
+        batch_image_slices = []
+        for i in range(x.size(0)):
+            # slice the vae and vit parts + slice the latent from x
+            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
+            gen_slices_i = [slice(3, x[i].size(1) - 1)]
+            batch_image_slices.append(joint_slices_i + gen_slices_i)

         attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
             for _, image_slice in enumerate(batch_image_slices[i]):
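The new loop builds per-sample lists of slice objects marking where the VAE/ViT conditioning tokens and the generated latent sit in the packed sequence. The body of the mask loop is cut off in this hunk; the sketch below assumes it opens full attention inside each image span on top of the causal mask, which is what the slice bookkeeping suggests:

import torch

# Toy illustration: start from a causal (lower-triangular) mask, then let the
# tokens inside each image span attend to one another bidirectionally.
# batch_image_slices is assumed to hold slice objects as built above.
bsz, seq_len = 2, 16
batch_image_slices = [[slice(3, 9)], [slice(3, 12)]]

attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
for i in range(bsz):
    for image_slice in batch_image_slices[i]:
        attention_mask[i, image_slice, image_slice] = True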
@@ -35,8 +35,7 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):

         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device())
-        latent = torch.cat([fn("<boi>"), fn("<all_img>_start"), fn("<img_size_1024>", special_fn), fn(f"<img_ratio_{height / width}", special_fn), fn("<timestep>", special_fn),
-                            latent, fn("<eoi>"), fn("<img>_start"), fn("<img>_end"), fn("<all_img>_end")], dim = 1)
+        latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn(f"<img_ratio_{height / width}", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)

         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )

 class HunyuanImage3Conditioning(io.ComfyNode):
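The empty latent is 32-channel at one sixteenth of the target resolution, as the torch.randn call above shows. A toy sketch of the resulting shape, assuming a 1024x1024 target:

import torch

# Toy illustration of the latent shape: 32 channels at (height // 16) x (width // 16).
batch_size, height, width = 1, 1024, 1024
latent = torch.randn(batch_size, 32, height // 16, width // 16)
print(latent.shape)  # torch.Size([1, 32, 64, 64])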
@@ -63,51 +62,29 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         def fn(string, func = encode_fn):
             return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0)

-        text_encoding = text_encoding[0][0]
+        text_tokens = text_encoding[0][0]
+        # should dynamically change in model logic
+        joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>")], dim = 1)

-        text_tokens = torch.cat([fn("<text>_start"), text_encoding, fn("<text>_end")], dim = 1)
-        vae_tokens = torch.cat([fn("<vae_img>_start"), fn("<joint_img>_start"), fn("<all_img>_start"), vae_encoding, fn("<vae_img>_end"), fn("<all_img>_end"), fn("<joint_img_sep>")], dim = 1)
-        vit_tokens = torch.cat([fn("<vit_img>_start"), fn("<all_img>_start"), vit_encoding, fn("<vit_img>_end"), fn("<joint_img>_end"), fn("<all_img>_end")], dim = 1)
-        n, seq_len, dim = vit_tokens.shape
-        vit_tokens = vit_tokens.reshape(n * seq_len, dim)
-        # should dynamically change in model logic
-        joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_tokens, vit_tokens, fn("<eoi>")], dim = 1)
+        vae_mask = torch.ones(joint_image.size(1))
+        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2)

-        seq_len_total = joint_image.shape[1]
-        mask = torch.zeros(seq_len_total, dtype=torch.bool, device=joint_image.device)
-        positions = {}
-        current = 4
-
-        def mark_region(name, tensor):
-            nonlocal current
-            start = current
-            current += tensor.shape[1]
-            end = current - 1
-            positions[f"<{name}>_start"] = start
-            positions[f"<{name}>_end"] = end
-            mask[start:end + 1] = True
-            return start, end
-
-        mark_region("vae_img", vae_tokens)
-
-        mask_list = []
-        for prefix in ["text", "vae_img", "vit_img"]:
-            start = positions[f"<{prefix}>_start"]
-            end = positions[f"<{prefix}>_end"]
-
-            section_mask = torch.arange(start, end + 1, device=mask.device)
-            mask_list.append(section_mask)
-
-        mask_list.insert(0, joint_image)
-        mask_list.append(text_tokens)
-        ragged_tensors = torch.nested.nested_tensor(mask_list, dtype=torch.long)
+        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])

+        uncond_ragged_tensors = None
         if text_encoding_negative is not None:
-            uncond_ragged_tensors = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None)
+            uncond_ragged_tensors, _ = cls.execute(vae_encoding, vit_encoding, text_encoding_negative, clip=clip, text_encoding_negative = None)
         else:
             uncond_ragged_tensors = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged_tensors.unbind()])

-        return ragged_tensors, uncond_ragged_tensors
+        if uncond_ragged_tensors is not None:
+            positive = [[ragged_tensors, {}]]
+            negative = [[uncond_ragged_tensors, {}]]
+        else:
+            positive = ragged_tensors
+            negative = uncond_ragged_tensors
+
+        return positive, negative

 class Image3Extension(ComfyExtension):
     @override
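The rewritten node packs joint_image, the VAE mask, and the text tokens into one ragged container via torch.nested.nested_tensor, and the unconditional branch rebuilds a zeroed copy from its unbind(). A minimal sketch of that packing with toy shapes (the real embedding sizes are not part of this diff):

import torch

# Toy illustration of the ragged packing: three tensors of different shapes
# stored in one nested tensor, then zeroed element-wise for the uncond branch.
joint_image = torch.randn(1, 10, 8)                          # stand-in embeddings
vae_mask = torch.ones(10).unsqueeze(0).unsqueeze(-1)         # (1, 10, 1)
text_tokens = torch.zeros(1, 7, 1, dtype=joint_image.dtype)  # stand-in token column

ragged = torch.nested.nested_tensor([joint_image, vae_mask, text_tokens])
uncond = torch.nested.nested_tensor([torch.zeros_like(t) for t in ragged.unbind()])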