fixes from testing
commit 9e9c536c8e
parent ca119c44fb
@@ -337,7 +337,7 @@ class UNetDown(nn.Module):
 
         if self.patch_size == 1:
             self.model.append(ResBlock(
-                in_channels=hidden_channels,
+                channels=hidden_channels,
                 emb_channels=emb_channels,
                 out_channels=out_channels,
                 dropout=dropout,
@@ -346,7 +346,7 @@ class UNetDown(nn.Module):
         else:
             for i in range(self.patch_size // 2):
                 self.model.append(ResBlock(
-                    in_channels=hidden_channels,
+                    channels=hidden_channels,
                     emb_channels=emb_channels,
                     out_channels=hidden_channels if (i + 1) * 2 != self.patch_size else out_channels,
                     dropout=dropout,
@@ -381,7 +381,7 @@ class UNetUp(nn.Module):
 
         if self.patch_size == 1:
             self.model.append(ResBlock(
-                in_channels=in_channels,
+                channels=in_channels,
                 emb_channels=emb_channels,
                 out_channels=hidden_channels,
                 dropout=dropout,
@@ -390,7 +390,7 @@ class UNetUp(nn.Module):
         else:
             for i in range(self.patch_size // 2):
                 self.model.append(ResBlock(
-                    in_channels=in_channels if i == 0 else hidden_channels,
+                    channels=in_channels if i == 0 else hidden_channels,
                     emb_channels=emb_channels,
                     out_channels=hidden_channels,
                     dropout=dropout,
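The four hunks above are the same fix applied in four places: the ResBlock being instantiated names its input-channel parameter `channels`, not `in_channels`, so the old keyword was rejected. A minimal sketch of the mismatch, assuming a constructor shaped like ComfyUI's `comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock` (the real one takes more arguments):

```python
import torch.nn as nn

class ResBlock(nn.Module):
    # input channel count is named `channels`, not `in_channels`
    def __init__(self, channels, emb_channels, dropout, out_channels=None):
        super().__init__()
        self.out_channels = out_channels or channels

# ResBlock(in_channels=1024, ...)  # TypeError: unexpected keyword argument 'in_channels'
block = ResBlock(channels=1024, emb_channels=4096, dropout=0.0, out_channels=4096)
```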
@@ -929,7 +929,7 @@ class HunyuanImage3DecoderLayer(nn.Module):
 
 class HunyuanImage3Model(nn.Module):
     def __init__(self, config, moe_lru=None):
-        super().__init__(config)
+        super().__init__()
         self.padding_idx = 128009
         self.vocab_size = 133120
         self.wte = nn.Embedding(133120, config["hidden_size"], self.padding_idx)
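Plain `nn.Module.__init__` takes no `config` argument (forwarding one is a transformers `PreTrainedModel` pattern), so passing it through raises a `TypeError` once the class derives from `nn.Module` directly. A minimal sketch:

```python
import torch.nn as nn

class HunyuanImage3Model(nn.Module):
    def __init__(self, config, moe_lru=None):
        # super().__init__(config) would raise a TypeError:
        # nn.Module's initializer accepts no config argument
        super().__init__()
        self.config = config

model = HunyuanImage3Model({"hidden_size": 4096})
```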
@@ -989,12 +989,12 @@ class HunyuanImage3Model(nn.Module):
 
 class HunyuanImage3ForCausalMM(nn.Module):
     def __init__(self, config):
-        super().__init__(config)
+        super().__init__()
         self.config = config
 
         self.timestep_emb = TimestepEmbedder(hidden_size=config["hidden_size"])
         self.patch_embed = UNetDown(
-            patch_size=16,
+            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=32,
            hidden_channels=1024,
@@ -1003,7 +1003,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
         self.time_embed = TimestepEmbedder(hidden_size=config["hidden_size"])
 
         self.final_layer = UNetUp(
-            patch_size=16,
+            patch_size=1,
            emb_channels=config["hidden_size"],
            in_channels=config["hidden_size"],
            hidden_channels=1024,
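Both `patch_embed` (UNetDown) and `final_layer` (UNetUp) drop from `patch_size=16` to `patch_size=1`, which selects the single-ResBlock branch of the constructors shown in the earlier hunks instead of stacking `patch_size // 2` blocks. A small sketch of how the block count follows from the branch logic above:

```python
def num_res_blocks(patch_size: int) -> int:
    # mirrors the if/else branch in UNetDown/UNetUp above
    return 1 if patch_size == 1 else patch_size // 2

assert num_res_blocks(1) == 1    # new: one ResBlock, no extra down/upsampling steps
assert num_res_blocks(16) == 8   # old: eight stacked blocks per stage
```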
@@ -1045,8 +1045,7 @@ class HunyuanImage3ForCausalMM(nn.Module):
 
     def forward(self, x, condition, timestep, **kwargs):
 
-        cond, uncond = condition[:4], condition[4:]
-        joint_image, cond_vae_image_mask, input_ids = cond[0], cond[1]
+        joint_image, cond_vae_image_mask, input_ids, uncond_joint, uncond_vae_mask, uncond_inputs = condition.unbind()
 
         if self.kv_cache is None:
             # TODO: should change when higgsv2 gets merged
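The conditioning now arrives as a single nested (ragged) tensor, as built by the `HunyuanImage3Conditioning` node further down, and `unbind()` splits it back into its components in one call instead of error-prone slicing (the old version also unpacked three names from two values). A round-trip sketch with placeholder shapes, three components as in the node below:

```python
import torch

# pack ragged (different-length) tensors, as the conditioning node does
joint_image = torch.randn(1, 100, 4096)
vae_mask = torch.ones(1, 100, 1)
text_tokens = torch.randn(1, 77, 1)
packed = torch.nested.nested_tensor([joint_image, vae_mask, text_tokens])

# unbind() recovers the original constituents with their own shapes
a, b, c = packed.unbind()
assert a.shape == (1, 100, 4096) and c.shape == (1, 77, 1)
```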
@@ -1058,9 +1057,11 @@ class HunyuanImage3ForCausalMM(nn.Module):
             )
 
         image_mask = torch.ones(x.size(1))
-        image_mask[:, :3] = torch.zeros(5); image_mask[:, -1] = torch.zeros(0)
+        image_mask[:3] = torch.zeros(3); image_mask[-1] = torch.zeros(1)
         gen_timestep_scatter_index = 4
-        joint_image[:, 2] = x[:, 2] # updates image ratio
+        with torch.no_grad():
+            joint_image[:, 2, 0] = x[:, 2, 0, 0] # updates image ratio
 
         position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=x.device)[None].expand(x.size(0), -1)
         height, width = x.shape[2] * 16, x.shape[3] * 16
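`image_mask` is 1-D with length `x.size(1)`, so the old 2-D indexing (`image_mask[:, :3]`) could not work, and the fill sizes did not match the slices (`torch.zeros(5)` into a 3-wide slice, `torch.zeros(0)` into a single slot). The ratio-token copy is also wrapped in `no_grad` and indexed down to the scalar it actually updates. A sketch of the corrected masking:

```python
import torch

seq_len = 10
image_mask = torch.ones(seq_len)      # 1-D: one flag per token slot
image_mask[:3] = torch.zeros(3)       # fill size now matches the slice width
image_mask[-1] = torch.zeros(1)       # single trailing slot, broadcasts cleanly
assert image_mask.tolist() == [0, 0, 0, 1, 1, 1, 1, 1, 1, 0]
```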
@@ -490,6 +490,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["max_position_embeddings"] = 12800
         dit_config["num_attention_heads"] = 32
         dit_config['rms_norm_eps'] = 1e-05
+        dit_config["num_hidden_layers"] = 32
         return dit_config
 
     if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2
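A hedged sketch of why detection needs the added key, assuming downstream model construction sizes its decoder stack from `config["num_hidden_layers"]` (consistent with the other fixed constants this branch sets); without it, that lookup fails:

```python
dit_config = {
    "max_position_embeddings": 12800,
    "num_attention_heads": 32,
    "rms_norm_eps": 1e-05,
    "num_hidden_layers": 32,  # the line this hunk adds
}

# hypothetical consumer: one decoder layer per configured hidden layer
num_layers = dit_config["num_hidden_layers"]  # KeyError before the fix
assert num_layers == 32
```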
@@ -30,12 +30,18 @@ class EmptyLatentHunyuanImage3(io.ComfyNode):
     def execute(cls, height, width, batch_size, clip):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
-        def fn(string, func = encode_fn):
-            return torch.tensor(func(string), device=comfy.model_management.intermediate_device()).unsqueeze(0)
+        word_embed = clip.tokenizer.wte
+        hidden_size = word_embed.weight.shape[1]
 
         height, width = get_target_size(height, width)
-        latent = torch.randn(batch_size, 32, height // 16, width // 16, device=comfy.model_management.intermediate_device())
-        latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn(f"<img_ratio_{height / width}", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
+        latent = torch.randn(batch_size, 32, int(height) // 16, int(width) // 16, device=comfy.model_management.intermediate_device())
+
+        def fn(string, func = encode_fn):
+            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
+                .view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, int(height) // 16, int(width) // 16)
+
+        latent = torch.cat([fn("<boi>"), fn("<img_size_1024>", func = special_fn), fn(f"<img_ratio_{int(height) // int(width)}>", special_fn), fn("<timestep>", special_fn), latent, fn("<eoi>")], dim = 1)
         return io.NodeOutput({"samples": latent, "type": "hunyuan_image_3"}, )
 
 class HunyuanImage3Conditioning(io.ComfyNode):
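The node now embeds the special tokens itself through the tokenizer's word-embedding table (`wte`) and broadcasts each embedding over the latent grid, instead of concatenating raw token ids onto the latent. A self-contained sketch of that broadcast, with made-up sizes and a hypothetical token id (the real `wte` comes from the loaded text model):

```python
import torch
import torch.nn as nn

batch_size, hidden_size = 1, 4096
h16, w16 = 64, 64                         # latent grid: height // 16, width // 16
wte = nn.Embedding(133120, hidden_size)   # stand-in for clip.tokenizer.wte

def embed_token(token_id: int) -> torch.Tensor:
    # one token id -> (1, hidden, 1, 1) -> broadcast across the whole grid
    vec = wte(torch.tensor([token_id]))
    return vec.view(1, hidden_size, 1, 1).expand(batch_size, hidden_size, h16, w16)

boi = embed_token(128000)  # hypothetical id for "<boi>"
assert boi.shape == (batch_size, hidden_size, h16, w16)
```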
@@ -59,15 +65,20 @@ class HunyuanImage3Conditioning(io.ComfyNode):
     def execute(cls, vae_encoding, vit_encoding, text_encoding, clip, text_encoding_negative=None):
         encode_fn = clip.tokenizer.tokenizer.convert_tokens_to_ids
         special_fn = clip.tokenizer.tokenizer.added_tokens_encoder
 
+        word_embed = clip.tokenizer.wte
+        batch_size, _, hidden_size = vae_encoding.shape
+
         def fn(string, func = encode_fn):
-            return torch.tensor(func(string), device=text_encoding.device).unsqueeze(0)
+            return word_embed(torch.tensor(func(string) if not isinstance(func, dict) else func[string], device=comfy.model_management.intermediate_device()))\
+                .view(1, hidden_size, 1, 1).view(1, 1, hidden_size).expand(batch_size, -1, hidden_size)
 
         text_tokens = text_encoding[0][0]
         # should dynamically change in model logic
         joint_image = torch.cat([fn("<boi>"), fn("<img_size_1024>", special_fn), fn("<img_ratio_3>", special_fn), fn("<timestep>", special_fn), vae_encoding, fn("<joint_img_sep>"), vit_encoding, fn("<eoi>")], dim = 1)
 
         vae_mask = torch.ones(joint_image.size(1))
-        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(2)
+        vae_mask[:3] = torch.zeros(3); vae_mask[vae_encoding.size(1) + 4:] = torch.zeros(len(vae_mask[vae_encoding.size(1) + 4:]))
 
         ragged_tensors = torch.nested.nested_tensor([joint_image, vae_mask.unsqueeze(0).unsqueeze(-1), text_tokens.unsqueeze(-1).to(joint_image.dtype)])
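The mask's tail fill is now length-agnostic: everything after the VAE span (its tokens plus the four leading special tokens) is zeroed with exactly as many zeros as remain, instead of hard-coding a 2-token suffix. The ragged pieces are then packed into a nested tensor, matching the `condition.unbind()` in the model's forward pass above. A sketch of the corrected mask:

```python
import torch

seq_len, vae_len = 12, 5
vae_mask = torch.ones(seq_len)
vae_mask[:3] = torch.zeros(3)
tail = vae_mask[vae_len + 4:]                     # whatever follows the VAE span
vae_mask[vae_len + 4:] = torch.zeros(len(tail))   # length-agnostic; was torch.zeros(2)
assert vae_mask.tolist() == [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0]
```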