removed additional token injection

This commit is contained in:
Yousef Rafat 2025-11-28 18:21:20 +02:00
parent 334041f6a6
commit bd2c2f7375
2 changed files with 2 additions and 2 deletions

View File

@ -1103,7 +1103,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
inputs_embeds = torch.cat([*input_args, joint_image], dim = 1)
else:
inputs_embeds = torch.cat([*input_args, joint_image[:, 1:, :]], dim = 1) # joint_image == eos_token
inputs_embeds = torch.cat([*input_args], dim = 1)
attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
for i in range(bsz):

View File

@ -76,7 +76,7 @@ class HunyuanImage3Conditioning(io.ComfyNode):
vae_mask = vae_mask.unsqueeze(0).unsqueeze(-1)
else:
pad_token = torch.tensor([-100.0]).view(1, 1, 1).expand(batch_size, 1, hidden_size)
joint_image = torch.cat([pad_token, fn("<|endoftext|>")], dim = 1)
joint_image = torch.cat([pad_token, fn("<|endoftext|>")], dim = 1) # look into
vae_mask = torch.empty_like(joint_image)
ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask, text_tokens.to(joint_image.dtype)])