Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-15 00:30:55 +08:00)

Commit 86e9a7a669 (parent a3ac798d4e): updates to the input
@@ -881,7 +881,6 @@ class HunyuanImage3Model(nn.Module):
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.config = config
-        self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
         self.layers = nn.ModuleList(
             [HunyuanImage3DecoderLayer(config, layer_idx, moe_lru = moe_lru) for layer_idx in range(config["num_hidden_layers"])]
         )
@@ -1026,10 +1025,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()

-        gen_timestep_scatter_index = 4
+        cond_exists = (joint_image[:, 0, :] == -100.0).any(dim=1).any()

-        with torch.no_grad():
-            joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio
+        gen_timestep_scatter_index = 4
+
+        if cond_exists:
+            with torch.no_grad():
+                joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio

         if self.first_step:
             token_height, token_width = x[:, -2:, 0].tolist()[0]
@@ -1039,28 +1041,30 @@ class HunyuanImage3ForCausalMM(nn.Module):
             token_height, token_width = self.token_dims

         img_slices = []
+        bsz, seq_len, n_embd = inputs_embeds.shape

         for i in range(x.size(0)):
-            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
-            vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1
+            gen_offset = seq_len + x.size(1)
+            if cond_exists:
+                vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
+                vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1

-            vit_start = vae_end + 1
-            vit_end = joint_image.size(1) - 1
+                vit_start = vae_end + 1 + gen_offset
+                vit_end = joint_image.size(1) - 1 + gen_offset

-            joint_slices_i = [
-                slice(vae_start, vae_end),
-                slice(vit_start, vit_end),
-            ]
-            gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)]
-            img_slices.append(joint_slices_i + gen_slices_i)
+                joint_slices_i = [
+                    slice(vae_start, vae_end),
+                    slice(vit_start, vit_end),
+                ]
+            gen_slices_i = [slice(seq_len, gen_offset)]
+            img_slices.append(gen_slices_i + joint_slices_i)

         img_s = img_slices[0]
-        rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]]
+        rope_image_info = [[(img_s[0], (token_height, token_width)), (img_s[1], (384 // 16, 384 // 16)), (img_s[2], (256 // 16, 256 // 16))]]

         cond_timestep = torch.zeros(inputs_embeds.size(0))
         t_emb = self.time_embed(cond_timestep)

-        bsz, seq_len, n_embd = inputs_embeds.shape

         if self.first_step:
             x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
@@ -1074,10 +1078,13 @@ class HunyuanImage3ForCausalMM(nn.Module):
         # cond_vae_images

         # cond_timestep_scatter_index
-        with torch.no_grad():
-            joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+        if cond_exists:
+            with torch.no_grad():
+                joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)

-        inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim = 1)
+            inputs_embeds = torch.cat([inputs_embeds, x, joint_image], dim = 1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, x, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token

         attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
         for i in range(bsz):
@@ -1123,7 +1130,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         hidden_states = hidden_states.to(inputs_embeds.device)
         img_mask = torch.zeros(hidden_states.size(1))
-        img_mask[-x.size(1)+4:] = 1; img_mask[-1] = 0
+        img_mask[seq_len + x.size(1)+4:] = 1; img_mask[-1] = 0

         diffusion_prediction = self.ragged_final_layer(
             hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)

@@ -32,8 +32,7 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
        encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
        special_fn = clip.tokenizer.tokenizer.added_tokens_encoder

-       # may convert clip.tokenizer -> clip.
-       word_embed = model.wte
+       word_embed = clip.wte
        patch_embed = model.patch_embed
        t_embed = model.time_embed

@@ -61,37 +60,42 @@ class HunyuanImage3Conditioning(io.ComfyNode):
             display_name="HunyuanImage3Conditioning",
             category="conditioning/video_models",
             inputs = [
-                io.Conditioning.Input("vae_encoding"),
-                io.Conditioning.Input("vit_encoding"),
                 io.Conditioning.Input("text_encoding_positive"),
                 io.Clip.Input("clip"),
                 io.Model.Input("model"),
+                io.Conditioning.Input("vae_encoding", optional=True),
+                io.Conditioning.Input("vit_encoding", optional=True),
                 io.Conditioning.Input("text_encoding_negative", optional = True),
             ],
             outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
         )

     @classmethod
-    def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, model, text_encoding_negative=None):
+    def execute(cls, text_encoding, clip, model, text_encoding_negative=None, vae_encoding = None, vit_encoding = None):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder

-        word_embed = model.wte
+        word_embed = clip.wte
         patch_embed = model.patch_embed
         t_embed = model.time_embed
-        batch_size, _, hidden_size = vit_encoding.shape

         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                 .view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)

         text_tokens = text_encoding[0][0]
-        vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0))))
-        # should dynamically change in model logic
-        joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>")], dim = 1)
         text_tokens = torch.cat([fn("<|startoftext|>"), text_tokens], dim = 1)
+        batch_size, _, hidden_size = text_tokens.shape

-        vae_mask = torch.ones(joint_image.size(1))
-        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
+        if vae_encoding is not None or vit_encoding is not None:
+            vae_encoding, _, _ = patch_embed(vae_encoding, t_embed(torch.tensor([0]).repeat(vae_encoding.size(0))))
+            # should dynamically change in model logic
+            joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>"), fn("<|endoftext|>")], dim = 1)
+            vae_mask = torch.ones(joint_image.size(1))
+            vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
+        else:
+            pad_token = torch.tensor([-100.0]).view(1, 1, 1).expand(batch_size, 1, hidden_size)
+            joint_image = torch.cat([pad_token, fn("<|endoftext|>")], dim = 1)

         ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)])