mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-14 08:10:51 +08:00)
removed all errors
commit 44346c4251
parent 5056a1f4d4
@@ -733,7 +733,7 @@ class HunyuanMoE(nn.Module):
                 expert = LazyMoELoader()
                 expert = expert.lazy_init(self.config, self.layer_idx, e)
                 self.moe_lru.add_gpu(expert, e + self.layer_idx)
-                experts_list.append((e, expert))
+            experts_list.append((e, expert))
 
         per_pos, per_tokens, per_weights = [], [], []
         for e, _ in experts_list:
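The hunk above adjusts where routed experts are appended to `experts_list` in the lazily loaded expert path. `LazyMoELoader` and `self.moe_lru` are defined elsewhere in the file and not shown in this commit; the sketch below is a minimal, illustrative version of the load-or-reuse pattern they suggest (an LRU of GPU-resident experts), with every name hypothetical rather than taken from the diff.

from collections import OrderedDict

import torch

class ExpertLRU:
    """Keep at most `capacity` experts resident; evict the least recently used."""
    def __init__(self, capacity, load_expert):
        self.capacity = capacity
        self.load_expert = load_expert  # hypothetical factory: (layer, idx) -> module
        self.cache = OrderedDict()      # (layer, idx) -> resident expert

    def get(self, layer_idx, expert_idx):
        key = (layer_idx, expert_idx)
        if key in self.cache:
            self.cache.move_to_end(key)                  # mark most recently used
            return self.cache[key]
        if len(self.cache) >= self.capacity:
            _, evicted = self.cache.popitem(last=False)  # least recently used
            evicted.to("cpu")                            # offload rather than delete
        expert = self.load_expert(layer_idx, expert_idx)
        self.cache[key] = expert
        return expert

# Toy usage: stand-in "experts" are plain tensors.
lru = ExpertLRU(capacity=2, load_expert=lambda layer, idx: torch.zeros(4))
e0 = lru.get(0, 0); e1 = lru.get(0, 1); e2 = lru.get(0, 2)  # evicts (0, 0)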
@@ -773,7 +773,8 @@ class HunyuanMoE(nn.Module):
         x = torch.bmm(tokens_padded, W1_T)
-        x = F.silu(x)
 
-        out_padded = torch.bmm(x, W2_T)
+        x1, x2 = x.chunk(2, dim=2)
+        out_padded = torch.bmm(x1 * F.silu(x2), W2_T)
 
         out_padded = out_padded * weights_padded.unsqueeze(-1)
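This hunk swaps the plain SiLU between the two batched matmuls for a SwiGLU-style gate: the up-projection `W1_T` now produces a double-width activation that is split in half, and the SiLU of one half gates the other before the down-projection `W2_T`. A self-contained sketch of the before/after, with illustrative shapes:

import torch
import torch.nn.functional as F

E, T, D, H = 2, 4, 8, 16                  # experts, tokens, model dim, hidden dim
tokens_padded = torch.randn(E, T, D)
W1_T = torch.randn(E, D, 2 * H)           # double width for the gated variant
W2_T = torch.randn(E, H, D)

x = torch.bmm(tokens_padded, W1_T)        # (E, T, 2*H)

# Before: out = torch.bmm(F.silu(x), W2_T), with W1_T shaped (E, D, H).
# After: SwiGLU-style gating.
x1, x2 = x.chunk(2, dim=2)                # two (E, T, H) halves
out_padded = torch.bmm(x1 * F.silu(x2), W2_T)   # (E, T, D)
print(out_padded.shape)                   # torch.Size([2, 4, 8])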
@@ -1025,6 +1026,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         self.first_step = True
 
         self.kv_cache = None
+        self.token_dims = ()
 
     @staticmethod
     def get_pos_emb(custom_pos_emb, position_ids):
@@ -1047,6 +1049,76 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
         joint_image, cond_vae_image_mask, inputs_embeds, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
 
+        gen_timestep_scatter_index = 4
+
+        with torch.no_grad():
+            joint_image[:, 2:3, :] = x[:, 2:3, :] # updates image ratio
+
+        if self.first_step:
+            token_height, token_width = x[:, -2:, 0].tolist()[0]
+            self.token_dims = (int(token_height), int(token_width))
+            x = x[:, :-2, :]
+        else:
+            token_height, token_width = self.token_dims
+
+        img_slices = []
+
+        for i in range(x.size(0)):
+            vae_mask_indices = (cond_vae_image_mask[i].squeeze(-1) == 1).nonzero(as_tuple=True)[0]
+            vae_start, vae_end = vae_mask_indices[0].item(), vae_mask_indices[-1].item() + 1
+
+            vit_start = vae_end + 1
+            vit_end = joint_image.size(1) - 1
+
+            joint_slices_i = [
+                slice(vae_start, vae_end),
+                slice(vit_start, vit_end),
+            ]
+            gen_slices_i = [slice(3 + vit_end, x[i].size(0) - 1 + vit_end)]
+            img_slices.append(joint_slices_i + gen_slices_i)
+
+        img_s = img_slices[0]
+        rope_image_info = [[(img_s[0], (384 // 16, 384 // 16)), (img_s[1], (256 // 16, 256 // 16)), (img_s[2], (token_height, token_width))]]
+
+        cond_timestep = torch.zeros(inputs_embeds.size(0))
+        t_emb = self.time_embed(cond_timestep)
+
+        bsz, seq_len, n_embd = inputs_embeds.shape
+
+        if self.first_step:
+            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+        else:
+            t_emb = self.time_embed(timestep)
+            x[:, 3:-1], token_height, token_width = self.patch_embed(x[:, 3:-1], t_emb)
+            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
+            x = torch.cat([timestep_emb, x], dim=1)
+
+        #/////////////
+        # cond_vae_images
+
+        # cond_timestep_scatter_index
+        with torch.no_grad():
+            joint_image[:, 3:4, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
+
+        inputs_embeds = torch.cat([inputs_embeds, joint_image, x], dim = 1)
+
+        attention_mask = torch.ones(inputs_embeds.shape[1], inputs_embeds.shape[1], dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
+        for i in range(bsz):
+            for _, image_slice in enumerate(img_slices[i]):
+                attention_mask[i, image_slice, image_slice] = True
+        attention_mask = attention_mask.unsqueeze(1)
+
+        # pos embed
+        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
+        cos, sin = build_batch_2d_rope(
+            image_infos=rope_image_info,
+            seq_len=inputs_embeds.shape[1],
+            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
+            base=10000.0,
+        )
+        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
+        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
+
         if self.kv_cache is None:
             # TODO: should change when higgsv2 gets merged
             self.kv_cache = HunyuanStaticCache(
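Among the added lines, the attention mask construction is the core trick: start from a causal lower-triangular mask, then open full bidirectional attention inside each image span (VAE tokens, ViT tokens, and the latent being generated), so text stays autoregressive while image tokens see each other. A standalone sketch with toy sizes:

import torch

bsz, seq_len = 1, 10
img_slices = [[slice(3, 7)]]      # pretend positions 3..6 are one image

mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
for i in range(bsz):
    for s in img_slices[i]:
        mask[i, s, s] = True      # image tokens may attend to "future" image tokens

# Row 3 now sees columns 3..6 even though 4..6 come later:
print(mask[0, 3, :].int().tolist())   # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]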
@@ -1056,70 +1128,6 @@ class HunyuanImage3ForCausalMM(nn.Module):
                 dtype=x.dtype,
             )
 
-        image_mask = torch.ones(x.size(1), device=x.device)
-        image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
-        gen_timestep_scatter_index = 4
-
-        with torch.no_grad():
-            joint_image[:, 2, :] = x[:, 2, :] # updates image ratio
-
-        position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
-        height, width = x.shape[2] * 16, x.shape[3] * 16
-        token_height = height // (16 * 16)
-        token_width = width // (16 * 16)
-
-        batch_image_slices = []
-        for i in range(x.size(0)):
-            # slice the vae and vit parts + slice the latent from x
-            joint_slices_i = [slice(3, cond_vae_image_mask[i].size(0) + 3), slice(cond_vae_image_mask[i].size(0) + 4, joint_image.size(1) - 1)]
-            gen_slices_i = [slice(3, x[i].size(1) - 1)]
-            batch_image_slices.append(joint_slices_i + gen_slices_i)
-
-        rope_image_info = [
-            [(s, (token_height, token_width)) for s in slices_i]
-            for slices_i in batch_image_slices
-        ]
-        seq_len = inputs_embeds.shape[1]
-        cos, sin = build_batch_2d_rope(
-            image_infos=rope_image_info,
-            seq_len=seq_len,
-            n_elem=self.config["hidden_size"] // self.config["num_attention_heads"],
-            base=10000.0,
-        )
-        custom_pos_emb = (sin.to(position_ids.device), cos.to(position_ids.device))
-
-        custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
-
-        cond_timestep = torch.zeros(inputs_embeds.size(0))
-        t_emb = self.time_embed(cond_timestep)
-
-        bsz, seq_len, n_embd = inputs_embeds.shape
-
-        # FIXME: token_h and token_w for the first step
-        if self.first_step:
-            x[:, gen_timestep_scatter_index:gen_timestep_scatter_index+1, :] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-        else:
-            t_emb = self.time_embed(timestep)
-            x[:, 3:-1], token_h, token_w = self.patch_embed(x[:, 3:-1], t_emb)
-            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
-            x = torch.cat([timestep_emb, x], dim=1)
-
-        inputs_embeds = torch.cat([inputs_embeds, x], dim = 1)
-
-        #/////////////
-        # cond_vae_images
-
-        # cond_timestep_scatter_index
-        joint_image[:, 3] = self.timestep_emb(timestep.reshape(-1)).reshape(bsz, -1, n_embd)
-
-        inputs_embeds = torch.cat([inputs_embeds, joint_image], dim = 1)
-
-        attention_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0).repeat(bsz, 1, 1)
-        for i in range(bsz):
-            for _, image_slice in enumerate(batch_image_slices[i]):
-                attention_mask[i, image_slice, image_slice] = True
-        attention_mask = attention_mask.unsqueeze(1)
-
         outputs = self.model(
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1137,8 +1145,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
         self.kv_cache = past_key_value
 
         hidden_states = hidden_states.to(inputs_embeds.device)
+        img_mask = torch.zeros(hidden_states.size(1))
+        img_mask[-x.size(1)+4:] = 1; img_mask[-1] = 0
+
         diffusion_prediction = self.ragged_final_layer(
-            hidden_states, image_mask, timestep, token_h, token_w, self.first_step)
+            hidden_states, img_mask, timestep, int(token_height), int(token_width), self.first_step)
 
         if self.first_step:
             self.first_step = False
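This hunk fixes names that only existed in the deleted block: `image_mask` and `token_h`/`token_w` are replaced by a freshly built `img_mask` and the cached `token_dims`. The mask marks which final hidden states belong to the generated latent so `ragged_final_layer` decodes only those; a sketch of its layout, assuming (as the surrounding code suggests) that the appended block `x` starts with four special positions and ends with `<eoi>`:

import torch

hidden_len, x_len = 20, 12          # full sequence length vs. appended image block
img_mask = torch.zeros(hidden_len)
img_mask[-x_len + 4:] = 1           # skip the block's 4 leading special positions
img_mask[-1] = 0                    # drop the trailing <eoi> position
print(img_mask.int().tolist())
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]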
@@ -268,7 +268,11 @@ class ResBlock(TimestepBlock):
         if emb_out is not None:
             if self.exchange_temb_dims:
                 emb_out = emb_out.movedim(1, 2)
-            h = h + emb_out
+            try:
+                h = h + emb_out
+            except:
+                emb_out = emb_out.movedim(1, 2)
+                h = h + emb_out
         h = self.out_layers(h)
         return self.skip_connection(x) + h
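The ResBlock change keeps the `exchange_temb_dims` path and additionally wraps the residual add in a try/except, retrying with `movedim(1, 2)` when the embedding's axes do not line up with the feature map. A sketch of the shape mismatch this catches, with illustrative shapes and an explicit `RuntimeError` instead of the bare `except`:

import torch

B, C, T = 2, 8, 5
h = torch.randn(B, C, T, 4, 4)           # conv features, channels first
emb_out = torch.randn(B, T, C, 1, 1)     # temporal embedding, channels last

try:
    h = h + emb_out                      # (B, T, C, ...) vs (B, C, T, ...): fails
except RuntimeError:
    emb_out = emb_out.movedim(1, 2)      # -> (B, C, T, 1, 1)
    h = h + emb_out                      # broadcasts over the spatial dims
print(h.shape)                           # torch.Size([2, 8, 5, 4, 4])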
@@ -39,13 +39,17 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
         height, width = get_target_size(height, width)
         latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
 
-        latent, _, _ = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
+        latent, tk_height, tk_width = patch_embed(latent, t_embed(torch.tensor([0]).repeat(batch_size)))
 
+        def tk_fn(token):
+            return torch.tensor([token], device = latent.device, dtype = latent.dtype).unsqueeze(1).expand(batch_size, 1, latent.size(-1))
+
         def fn(string, func = encode_fn):
             return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
                 .unsqueeze(0).expand(batch_size, -1, -1)
 
         latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
+        latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
 
 class HunyuanImage3Conditioning(io.ComfyNode):
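The node now keeps the token grid that `patch_embed` reports and smuggles it through the latent as two extra embedding rows, which the model reads back from channel 0 via `x[:, -2:, 0]` on the first step (see the `token_dims` logic above) before stripping them. A sketch of the shape flow of the added `tk_fn`, with illustrative sizes:

import torch

batch_size, n_embd = 2, 16
latent = torch.randn(batch_size, 10, n_embd)

def tk_fn(token):
    return torch.tensor([token], device=latent.device, dtype=latent.dtype)\
        .unsqueeze(1).expand(batch_size, 1, latent.size(-1))

tk_height, tk_width = 64, 64
latent = torch.cat([latent, tk_fn(tk_height), tk_fn(tk_width)], dim=1)
print(latent.shape, latent[:, -2:, 0].tolist())
# torch.Size([2, 12, 16]) [[64.0, 64.0], [64.0, 64.0]]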
@@ -87,7 +91,7 @@ class HunyuanImage3Conditioning(io.ComfyNode):
         vae_mask = torch.ones(joint_image.size(1))
         vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
 
-        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])
+        ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.to(joint_image.dtype)])
 
         uncond_ragged_tensors = None
         if text_encoding_negative is not None:
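The one-line fix drops the extra `unsqueeze(-1)` on `text_tokens`, presumably because the tokens already carry a trailing dimension: `torch.nested.nested_tensor` requires every component to have the same number of dimensions, even though their sizes may be ragged. A sketch of that constraint with illustrative shapes:

import torch

a = torch.randn(1, 7, 4)    # e.g. joint_image
b = torch.randn(1, 5, 1)    # e.g. vae_mask after its two unsqueezes
c = torch.randn(1, 9, 1)    # e.g. text_tokens, already 3-D

nt = torch.nested.nested_tensor([a, b, c])    # ok: all components are 3-D
print(nt.is_nested)                           # True
# torch.nested.nested_tensor([a, b, c.unsqueeze(-1)]) raises, since the
# extra unsqueeze makes one component 4-D while the others stay 3-D.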