From f00f340a56013001a6148eee6d3d00d02078e43e Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 16 Apr 2025 17:43:55 -0400
Subject: [PATCH 1/2] Reuse code from flux model.

---
 comfy/ldm/hidream/model.py | 39 ++++-----------------------------------
 1 file changed, 4 insertions(+), 35 deletions(-)

diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py
index de749a373..39c67a193 100644
--- a/comfy/ldm/hidream/model.py
+++ b/comfy/ldm/hidream/model.py
@@ -8,26 +8,12 @@ from einops import repeat
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 import torch.nn.functional as F
 
-from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.flux.math import apply_rope, rope
+from comfy.ldm.flux.layers import LastLayer
+
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
-# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
-    assert dim % 2 == 0, "The dimension must be even."
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
 
 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
 class EmbedND(nn.Module):
@@ -84,23 +70,6 @@ class TimestepEmbed(nn.Module):
         return t_emb
 
 
-class OutEmbed(nn.Module):
-    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
-        )
-
-    def forward(self, x, adaln_input):
-        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
-        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-        x = self.linear(x)
-        return x
-
-
 def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
     return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
 
@@ -663,7 +632,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
             ]
         )
 
-        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
         caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
         caption_projection = []

From 9899d187b16a9a823a98fc1df9bf1fbb58674087 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 16 Apr 2025 18:07:55 -0400
Subject: [PATCH 2/2] Limit T5 to 128 tokens for HiDream: #7620

---
 comfy/text_encoders/hidream.py  | 5 +++--
 comfy/text_encoders/sd3_clip.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py
index af105f9bb..6c34c5572 100644
--- a/comfy/text_encoders/hidream.py
+++ b/comfy/text_encoders/hidream.py
@@ -11,14 +11,15 @@ class HiDreamTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data)
+        self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
         self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
         out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
         return out
 
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 1727998a8..6c2fbeca4 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -32,9 +32,9 @@ def t5_xxl_detect(state_dict, prefix=""):
     return out
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77):
+    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
-        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
 
 
 class SD3Tokenizer:
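
For reference, outside the patches themselves: the rope() helper that patch 1 deletes from comfy/ldm/hidream/model.py is the same rotary-embedding code now imported from comfy.ldm.flux.math, and OutEmbed is swapped for the equivalent flux LastLayer. The standalone sketch below simply reproduces the removed rope() code with a small shape check; the example positions and sizes are arbitrary and only illustrative.

import torch

# Copy of the rope() helper removed in patch 1 (now imported from comfy.ldm.flux.math).
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()

# Arbitrary example: 4 positions, an 8-dim axis -> rotation matrices of shape (1, 4, 4, 2, 2).
pos = torch.arange(4, dtype=torch.float64)[None]
print(rope(pos, dim=8, theta=10000).shape)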
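
Also for reference, on patch 2: tokenize_with_weights() on the T5 tokenizer returns a list of token batches, so with max_length=128 a long prompt splits into several 128-token batches, and keeping only t5xxl[0] caps the T5-XXL input at the first one. A minimal sketch of that chunk-then-keep-first idea, assuming a hypothetical chunk_tokens() helper and fake token ids (neither is a ComfyUI API):

# Hypothetical illustration of the 128-token cap from patch 2; chunk_tokens() is not a ComfyUI API.
def chunk_tokens(token_ids, chunk_size=128, pad_id=0):
    # Split a flat token list into fixed-size chunks, padding the last chunk.
    chunks = []
    for start in range(0, len(token_ids), chunk_size):
        chunk = token_ids[start:start + chunk_size]
        chunks.append(chunk + [pad_id] * (chunk_size - len(chunk)))
    return chunks

token_ids = list(range(300))        # stand-in for a long tokenized prompt
t5xxl = chunk_tokens(token_ids)     # plays the role of the tokenizer's batched output
out_t5xxl = [t5xxl[0]]              # mirrors: out["t5xxl"] = [t5xxl[0]]
assert len(out_t5xxl) == 1 and len(out_t5xxl[0]) == 128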