From 3e8155f7a3d7601838bbc82a8ccf550343bbb132 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 10:32:21 -0400 Subject: [PATCH 01/27] More flexible long clip support. Add clip g long clip support. Text encoder refactor. Support llama models with different vocab sizes. --- comfy/sd1_clip.py | 23 ++++++++++++++--- comfy/sdxl_clip.py | 14 +++++------ comfy/text_encoders/aura_t5.py | 2 +- comfy/text_encoders/cosmos.py | 2 +- comfy/text_encoders/flux.py | 10 +++----- comfy/text_encoders/genmo.py | 2 +- comfy/text_encoders/hunyuan_video.py | 22 ++++++++++------- comfy/text_encoders/hydit.py | 8 +++--- comfy/text_encoders/llama.py | 14 ++++++++++- comfy/text_encoders/long_clipl.py | 37 ++++++++++++++-------------- comfy/text_encoders/lt.py | 2 +- comfy/text_encoders/lumina2.py | 2 +- comfy/text_encoders/pixart_t5.py | 2 +- comfy/text_encoders/sa_t5.py | 2 +- comfy/text_encoders/sd2_clip.py | 2 +- comfy/text_encoders/sd3_clip.py | 15 ++++++----- comfy/text_encoders/wan.py | 2 +- 17 files changed, 95 insertions(+), 66 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index be21ec18d..2ca5ed9ba 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): LAYERS = [ "last", "pooled", - "hidden" + "hidden", + "all" ] def __init__(self, device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, @@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + if "model_name" not in model_options: + model_options = {**model_options, "model_name": "clip_l"} if isinstance(textmodel_json_config, dict): config = textmodel_json_config @@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) + te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {}) + for k, v in te_model_options.items(): + config[k] = v + operations = model_options.get("custom_operations", None) scaled_fp8 = None @@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: + if self.layer == "all": + pass + elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + if self.layer == "all": + intermediate_output = "all" + else: + intermediate_output = self.layer_idx + + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() @@ -447,7 +461,7 @@ class SDTokenizer: if tokenizer_path is None: 
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) - self.max_length = max_length + self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None @@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) clip_model = model_options.get("{}_class".format(self.clip), clip_model) + model_options = {**model_options, "model_name": self.clip} setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 5b7c8a412..ea7f5d10f 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options) @@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -41,8 +41,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, 
pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): class StableCascadeClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options) diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py index e9ad45a7f..cf4252eea 100644 --- a/comfy/text_encoders/aura_t5.py +++ b/comfy/text_encoders/aura_t5.py @@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel): class PT5XlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data) class AuraT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 5441c8952..a1adb5242 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data) class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index a12995ec0..0666dde7f 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -9,14 +9,13 @@ import os class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, 
embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class FluxTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.dtypes = set([dtype, dtype_t5]) diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 45987a480..9dcf190a2 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index dbb259e54..33ac22497 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, 
embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data) class LLAMAModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) if llama_scaled_fp8 is not None: model_options = model_options.copy() model_options["scaled_fp8"] = llama_scaled_fp8 - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + textmodel_json_config = {} + vocab_size = model_options.get("vocab_size", None) + if vocab_size is not None: + textmodel_json_config["vocab_size"] = vocab_size + + model_options = {**model_options, "model_name": "llama"} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. 
camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens - self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) + self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} @@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) self.dtypes = set([dtype, dtype_llama]) diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index 7da3e9fc5..e7273f425 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -9,24 +9,26 @@ import torch class HyditBertModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") + model_options = {**model_options, "model_name": "hydit_clip"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class HyditBertTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data) class MT5XLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") + model_options = {**model_options, "model_name": "mt5xl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class MT5XLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model") tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, 
embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -35,7 +37,7 @@ class HyditTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None) self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) - self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) + self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 58710b2bf..34eb870e3 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -268,11 +268,17 @@ class Llama2_(nn.Module): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output for i, layer in enumerate(self.layers): + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -283,6 +289,12 @@ class Llama2_(nn.Module): intermediate = x.clone() x = self.norm(x) + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + if intermediate is not None and final_layer_norm_intermediate: intermediate = self.norm(intermediate) diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index b81912cb3..f9483b427 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,30 +1,29 @@ from comfy import sd1_clip import os -class LongClipTokenizer_(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - -class LongClipModel_(sd1_clip.SDClipModel): - def __init__(self, *args, **kwargs): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") - super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs) - -class LongClipTokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_) - -class LongClipModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): - super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) def model_options_long_clip(sd, tokenizer_data, model_options): w = 
sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None) + else: + model_name = "clip_g" + if w is None: w = sd.get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: + if w is not None: + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + model_name = "clip_g" + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + model_name = "clip_l" + else: + model_name = "clip_l" + + if w is not None: tokenizer_data = tokenizer_data.copy() model_options = model_options.copy() - tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ - model_options["clip_l_class"] = LongClipModel_ + model_config = model_options.get("model_config", {}) + model_config["max_position_embeddings"] = w.shape[0] + model_options["{}_model_config".format(model_name)] = model_config + tokenizer_data["{}_max_length".format(model_name)] = w.shape[0] return tokenizer_data, model_options diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5c2ce583f..48ea67e67 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,7 +6,7 @@ import comfy.text_encoders.genmo class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128? + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128? 
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index a7b1d702b..674461b75 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -6,7 +6,7 @@ import comfy.text_encoders.llama class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index d56d57f1b..b8de6bc4e 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py index 7778ce47a..2803926ac 100644 --- a/comfy/text_encoders/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel): class T5BaseTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) class SAT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py index 31fc89869..700a23bf0 100644 --- a/comfy/text_encoders/sd2_clip.py +++ b/comfy/text_encoders/sd2_clip.py @@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, 
tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data) class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 3ad2ed93a..1727998a8 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel): model_options = model_options.copy() model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module): super().__init__() self.dtypes = set() if clip_l: - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) self.dtypes.add(dtype) else: self.clip_l = None diff --git a/comfy/text_encoders/wan.py 
b/comfy/text_encoders/wan.py index 971ac8fa8..d50fa4b28 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel): class UMT5XXlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0) + super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} From 6fc5dbd52ab70952020e6bc486c4d851a7ba6625 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 12:13:28 -0400 Subject: [PATCH 02/27] Cleanup. --- comfy/rmsnorm.py | 11 ----------- comfy/text_encoders/long_clipl.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 81b3e9062..77df44464 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -27,17 +27,6 @@ def rms_norm(x, weight=None, eps=1e-6): if RMSNorm is None: class RMSNorm(torch.nn.Module): - def __init__( - self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs - ): - super().__init__() - self.eps = eps - self.learnable_scale = elementwise_affine - if self.learnable_scale: - self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) - else: - self.register_parameter("weight", None) - def __init__( self, normalized_shape, diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index f9483b427..8d4c7619d 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,5 +1,3 @@ -from comfy import sd1_clip -import os def model_options_long_clip(sd, tokenizer_data, model_options): From 9ad792f92706e2179c58b2e5348164acafa69288 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 17:35:05 -0400 Subject: [PATCH 03/27] Basic support for hidream i1 model. 
--- comfy/ldm/hidream/model.py | 828 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 18 + comfy/model_detection.py | 19 + comfy/ops.py | 3 + comfy/sd.py | 4 + comfy/supported_models.py | 32 +- comfy/text_encoders/hidream.py | 150 ++++++ comfy_extras/nodes_hidream.py | 32 ++ nodes.py | 3 +- 9 files changed, 1087 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/hidream/model.py create mode 100644 comfy/text_encoders/hidream.py create mode 100644 comfy_extras/nodes_hidream.py diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py new file mode 100644 index 000000000..de749a373 --- /dev/null +++ b/comfy/ldm/hidream/model.py @@ -0,0 +1,828 @@ +from typing import Optional, Tuple, List + +import torch +import torch.nn as nn +import einops +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +import torch.nn.functional as F + +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + dtype=None, device=None, operations=None + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, latent): + latent = self.proj(latent) + return latent + + +class PooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, pooled_embed): + return self.pooled_embedder(pooled_embed) + + +class TimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, timesteps, wdtype): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +class 
OutEmbed(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) + ) + + def forward(self, x, adaln_input): + shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) + x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = self.linear(x) + return x + + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) + + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) + value_t = attn.to_v_t(text_tokens) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == rope.shape[-3] * 2: + query, key = apply_rope(query, key, rope) + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, rope) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = attention(query, key, value) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + +class HiDreamAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: 
int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor = None, + out_dim: int = None, + single: bool = False, + dtype=None, device=None, operations=None + ): + # super(Attention, self).__init__() + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + linear_cls = operations.Linear + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + if not single: + self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + self.processor = processor + + def forward( + self, + norm_image_tokens: torch.FloatTensor, + image_tokens_masks: torch.FloatTensor = None, + norm_text_tokens: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.Tensor: + return self.processor( + self, + image_tokens = norm_image_tokens, + image_tokens_masks = image_tokens_masks, + text_tokens = norm_text_tokens, + rope = rope, + ) + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + dtype=None, device=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = 'softmax' + self.alpha = aux_loss_alpha + 
self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device)) + self.reset_parameters() + + def reset_parameters(self) -> None: + pass + # import torch.nn.init as init + # init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + dtype=None, device=None, operations=None + ): + super().__init__() + self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations) + self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)]) + self.gate = MoEGate( + embed_dim = dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if True: # self.training: # TODO: check which branch performs faster + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = torch.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + #y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + 
expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + return expert_cache + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class BlockType: + TransformerBlock = 1 + SingleTransformerBlock = 2 + + +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device) + ) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = True, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_image_tokens, + image_tokens_masks, + rope = rope, + ) + image_tokens = gate_msa_i * attn_output_i + image_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) + image_tokens = ff_output_i + image_tokens + return image_tokens + + +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device) + ) + # nn.init.zeros_(self.adaLN_modulation[1].weight) + # nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. 
Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = False, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ + shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_image_tokens, + image_tokens_masks, + norm_text_tokens, + rope = rope, + ) + + image_tokens = gate_msa_i * attn_output_i + image_tokens + text_tokens = gate_msa_t * attn_output_t + text_tokens + + # 2. 
Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) + ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) + image_tokens = ff_output_i + image_tokens + text_tokens = ff_output_t + text_tokens + return image_tokens, text_tokens + + +class HiDreamImageBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + block_type: BlockType = BlockType.TransformerBlock, + dtype=None, device=None, operations=None + ): + super().__init__() + block_classes = { + BlockType.TransformerBlock: HiDreamImageTransformerBlock, + BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, + } + self.block = block_classes[block_type]( + dim, + num_attention_heads, + attention_head_dim, + num_routed_experts, + num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + return self.block( + image_tokens, + image_tokens_masks, + text_tokens, + adaln_input, + rope, + ) + + +class HiDreamImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + image_model=None, + dtype=None, device=None, operations=None + ): + self.patch_size = patch_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.num_layers = num_layers + self.num_single_layers = num_single_layers + + self.gradient_checkpointing = False + + super().__init__() + self.dtype = dtype + self.out_channels = out_channels or in_channels + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations) + self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.x_embedder = PatchEmbed( + patch_size = patch_size, + in_channels = in_channels, + out_channels = self.inner_dim, + dtype=dtype, device=device, operations=operations + ) + self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.TransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_layers) + ] + ) + + self.single_stream_blocks = 
nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.SingleTransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_single_layers) + ] + ) + + self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]: + x_arr = [] + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + x_arr.append( + einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', + p1=self.patch_size, p2=self.patch_size) + ) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.patch_size * self.patch_size + if isinstance(x, torch.Tensor): + B = x.shape[0] + device = x.device + dtype = x.dtype + else: + B = len(x) + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0:img_size[0] * img_size[1]] = 1 + x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) + elif isinstance(x, torch.Tensor): + pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + control = None, + transformer_options = {}, + ) -> torch.Tensor: + hidden_states = x + timesteps = t + pooled_embeds = y + T5_encoder_hidden_states = context + + img_sizes = None + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. 
time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + # T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # 2. 
Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states = block( + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=None, + adaln_input=adaln_input, + rope=rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
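+ # Only the image tokens remain at this point; the text tokens appended for the
+ # single-stream blocks were already sliced back off inside the loop above.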
+ output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes) + return -output diff --git a/comfy/model_base.py b/comfy/model_base.py index 6bc627ae3..8dab1740b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model +import comfy.ldm.hidream.model import comfy.model_management import comfy.patcher_extension @@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HiDream(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + conditioning_llama3 = kwargs.get("conditioning_llama3", None) + if conditioning_llama3 is not None: + out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4217f5831..a4da1afcd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -338,6 +338,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream + dit_config = {} + dit_config["image_model"] = "hidream" + dit_config["attention_head_dim"] = 128 + dit_config["axes_dims_rope"] = [64, 32, 32] + dit_config["caption_channels"] = [4096, 4096] + dit_config["max_resolution"] = [128, 128] + dit_config["in_channels"] = 16 + dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31] + dit_config["num_attention_heads"] = 20 + dit_config["num_routed_experts"] = 4 + dit_config["num_activated_experts"] = 2 + dit_config["num_layers"] = 16 + dit_config["num_single_layers"] = 32 + dit_config["out_channels"] = 16 + dit_config["patch_size"] = 2 + dit_config["text_emb_dim"] = 2048 + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/ops.py b/comfy/ops.py index 6b0e29307..aae6cafac 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -263,6 +263,9 @@ class manual_cast(disable_weight_init): class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True + class RMSNorm(disable_weight_init.RMSNorm): + comfy_cast_weights = True + class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True diff --git a/comfy/sd.py b/comfy/sd.py index 4d3aef3e1..d97873ba2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.hidream import comfy.model_patcher import comfy.lora @@ -853,6 +854,9 @@ def load_text_encoder_state_dicts(state_dicts=[], 
embedding_directory=None, clip elif len(clip_data) == 3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif len(clip_data) == 4: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer parameters = 0 for c in clip_data: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2a6a61560..81c47ac68 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +class HiDream(supported_models_base.BASE): + unet_config = { + "image_model": "hidream", + } + + sampling_settings = { + "shift": 3.0, + } + + sampling_settings = { + } + + # memory_usage_factor = 1.2 # TODO + + unet_extra_config = {} + latent_format = latent_formats.Flux + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HiDream(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py new file mode 100644 index 000000000..af105f9bb --- /dev/null +++ b/comfy/text_encoders/hidream.py @@ -0,0 +1,150 @@ +from . import hunyuan_video +from . 
import sd3_clip +from comfy import sd1_clip +from comfy import sdxl_clip +import comfy.model_management +import torch +import logging + + +class HiDreamTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data) + self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return self.clip_g.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class HiDreamTEModel(torch.nn.Module): + def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + if clip_l: + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_g = None + + if t5: + dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) + self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True) + self.dtypes.add(dtype_t5) + else: + self.t5xxl = None + + if llama: + dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) + if "vocab_size" not in model_options: + model_options["vocab_size"] = 128256 + self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009}) + self.dtypes.add(dtype_llama) + else: + self.llama = None + + logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama)) + + def set_clip_options(self, options): + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) + if self.llama is not None: + self.llama.set_clip_options(options) + + def reset_clip_options(self): + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() + if self.llama is not None: + self.llama.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_g = token_weight_pairs["g"] + token_weight_pairs_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_llama = 
token_weight_pairs["llama"] + lg_out = None + pooled = None + extra = {} + + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + else: + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_out, t5_pooled = t5_output[:2] + + if self.llama is not None: + ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) + ll_out, ll_pooled = ll_output[:2] + ll_out = ll_out[:, 1:] + + if t5_out is None: + t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) + + if ll_out is None: + ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) + + extra["conditioning_llama3"] = ll_out + return t5_out, pooled, extra + + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: + return self.t5xxl.load_sd(sd) + else: + return self.llama.load_sd(sd) + + +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): + class HiDreamTEModel_(HiDreamTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return HiDreamTEModel_ diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py new file mode 100644 index 000000000..5a160c2ba --- /dev/null +++ b/comfy_extras/nodes_hidream.py @@ -0,0 +1,32 @@ +import folder_paths +import comfy.sd +import comfy.model_management + + +class QuadrupleCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" + + def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = 
folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + + +NODE_CLASS_MAPPINGS = { + "QuadrupleCLIPLoader": QuadrupleCLIPLoader, +} diff --git a/nodes.py b/nodes.py index e66b5c714..ae0a2e183 100644 --- a/nodes.py +++ b/nodes.py @@ -2280,7 +2280,8 @@ def init_builtin_extra_nodes(): "nodes_hunyuan3d.py", "nodes_primitive.py", "nodes_cfg.py", - "nodes_optimalsteps.py" + "nodes_optimalsteps.py", + "nodes_hidream.py" ] import_failed = [] From b4dc03ad7669b155d3c7714e9e5a474365d50c8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 04:53:56 -0400 Subject: [PATCH 04/27] Fix issue on old torch. --- comfy/rmsnorm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 77df44464..9d82bee1a 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -49,6 +49,7 @@ if RMSNorm is None: ) else: self.register_parameter("weight", None) + self.bias = None def forward(self, x): return rms_norm(x, self.weight, self.eps) From cce1d9145e06c0f86336a2a7f5558610fdc76718 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 16 Apr 2025 15:41:00 -0400 Subject: [PATCH 05/27] [Type] Mark input options NotRequired (#7614) --- comfy/comfy_types/node_typing.py | 54 ++++++++++++++++---------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 3535966fb..42ed5174e 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes """ - default: bool | str | float | int | list | tuple + default: NotRequired[bool | str | float | int | list | tuple] """The default value of the widget""" - defaultInput: bool + defaultInput: NotRequired[bool] """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist. - defaultInput on required inputs should be dropped. - defaultInput on optional inputs should be replaced with forceInput. Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364 """ - forceInput: bool + forceInput: NotRequired[bool] """Forces the input to be an input slot rather than a widget even a widget is available for the input type.""" - lazy: bool + lazy: NotRequired[bool] """Declares that this input uses lazy evaluation""" - rawLink: bool + rawLink: NotRequired[bool] """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). 
Designed for node expansion.""" - tooltip: str + tooltip: NotRequired[str] """Tooltip for the input (or widget), shown on pointer hover""" # class InputTypeNumber(InputTypeOptions): # default: float | int - min: float + min: NotRequired[float] """The minimum value of a number (``FLOAT`` | ``INT``)""" - max: float + max: NotRequired[float] """The maximum value of a number (``FLOAT`` | ``INT``)""" - step: float + step: NotRequired[float] """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)""" - round: float + round: NotRequired[float] """Floats are rounded by this value (``FLOAT``)""" # class InputTypeBoolean(InputTypeOptions): # default: bool - label_on: str + label_on: NotRequired[str] """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_off: str + label_off: NotRequired[str] """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str - multiline: bool + multiline: NotRequired[bool] """Use a multiline text box (``STRING``)""" - placeholder: str + placeholder: NotRequired[str] """Placeholder text to display in the UI when empty (``STRING``)""" # Deprecated: # defaultVal: str - dynamicPrompts: bool + dynamicPrompts: NotRequired[bool] """Causes the front-end to evaluate dynamic prompts (``STRING``)""" # class InputTypeCombo(InputTypeOptions): - image_upload: bool + image_upload: NotRequired[bool] """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`.""" - image_folder: Literal["input", "output", "temp"] + image_folder: NotRequired[Literal["input", "output", "temp"]] """Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """ - remote: RemoteInputOptions + remote: NotRequired[RemoteInputOptions] """Specifies the configuration for a remote input. Available after ComfyUI frontend v1.9.7 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422""" - control_after_generate: bool + control_after_generate: NotRequired[bool] """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" options: NotRequired[list[str | int | float]] """COMBO type only. Specifies the selectable options for the combo widget. @@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict): class HiddenInputTypeDict(TypedDict): """Provides type hinting for the hidden entry of node INPUT_TYPES.""" - node_id: Literal["UNIQUE_ID"] + node_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - unique_id: Literal["UNIQUE_ID"] + unique_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - prompt: Literal["PROMPT"] + prompt: NotRequired[Literal["PROMPT"]] """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" - extra_pnginfo: Literal["EXTRA_PNGINFO"] + extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]] """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. 
Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" - dynprompt: Literal["DYNPROMPT"] + dynprompt: NotRequired[Literal["DYNPROMPT"]] """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" @@ -187,11 +187,11 @@ class InputTypeDict(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs """ - required: dict[str, tuple[IO, InputTypeOptions]] + required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes all inputs that must be connected for the node to execute.""" - optional: dict[str, tuple[IO, InputTypeOptions]] + optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes inputs which do not need to be connected.""" - hidden: HiddenInputTypeDict + hidden: NotRequired[HiddenInputTypeDict] """Offers advanced functionality and server-client communication. Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs From f00f340a56013001a6148eee6d3d00d02078e43e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 17:43:55 -0400 Subject: [PATCH 06/27] Reuse code from flux model. --- comfy/ldm/hidream/model.py | 39 ++++---------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index de749a373..39c67a193 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -8,26 +8,12 @@ from einops import repeat from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps import torch.nn.functional as F -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.flux.layers import LastLayer + from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." 
- - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim - omega = 1.0 / (theta**scale) - - batch_size, seq_length = pos.shape - out = torch.einsum("...n,d->...nd", pos, omega) - cos_out = torch.cos(out) - sin_out = torch.sin(out) - - stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) - return out.float() - # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py class EmbedND(nn.Module): @@ -84,23 +70,6 @@ class TimestepEmbed(nn.Module): return t_emb -class OutEmbed(nn.Module): - def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): - super().__init__() - self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) - ) - - def forward(self, x, adaln_input): - shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) - x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - x = self.linear(x) - return x - - def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) @@ -663,7 +632,7 @@ class HiDreamImageTransformer2DModel(nn.Module): ] ) - self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] caption_projection = [] From 9899d187b16a9a823a98fc1df9bf1fbb58674087 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 18:07:55 -0400 Subject: [PATCH 07/27] Limit T5 to 128 tokens for HiDream: #7620 --- comfy/text_encoders/hidream.py | 5 +++-- comfy/text_encoders/sd3_clip.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index af105f9bb..6c34c5572 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -11,14 +11,15 @@ class HiDreamTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data) self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["l"] = 
self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) return out diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 1727998a8..6c2fbeca4 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -32,9 +32,9 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: From 1fc00ba4b6576ed5910a88caa47866774ee6d0ca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 18:34:14 -0400 Subject: [PATCH 08/27] Make hidream work with any latent resolution. --- comfy/ldm/hidream/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index 39c67a193..fcb5a9c51 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.ldm.common_dit # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py @@ -701,7 +702,8 @@ class HiDreamImageTransformer2DModel(nn.Module): control = None, transformer_options = {}, ) -> torch.Tensor: - hidden_states = x + bs, c, h, w = x.shape + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) timesteps = t pooled_embeds = y T5_encoder_hidden_states = context @@ -794,4 +796,4 @@ class HiDreamImageTransformer2DModel(nn.Module): hidden_states = hidden_states[:, :image_tokens_seq_len, ...] output = self.final_layer(hidden_states, adaln_input) output = self.unpatchify(output, img_sizes) - return -output + return -output[:, :, :h, :w] From 0d720e4367c1c149dbfa0a98ebd81c7776914545 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 06:25:39 -0400 Subject: [PATCH 09/27] Don't hardcode length of context_img in wan code. 
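
The CLIP vision tokens produced by img_emb are concatenated in front of the text
context on the i2v path, so the cross attention now receives their count
(clip_fea.shape[-2]) instead of assuming 257. A minimal sketch of the split, with
hypothetical shapes:

    import torch

    def split_context(context, context_img_len):
        # the first context_img_len tokens come from the image embedding,
        # the rest are the text encoder tokens
        return context[:, :context_img_len], context[:, context_img_len:]

    ctx = torch.randn(1, 257 + 512, 4096)  # [B, L_img + L_txt, C], made-up sizes
    context_img, context_txt = split_context(ctx, context_img_len=257)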
--- comfy/ldm/wan/model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9b5e5332c..d64e73a8e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context): + def forward(self, x, context, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context): + def forward(self, x, context, context_img_len): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - context_img = context[:, :257] - context = context[:, 257:] + context_img = context[:, :context_img_len] + context = context[:, context_img_len:] # compute query, key, value q = self.norm_q(self.q(x)) @@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, + context_img_len=None, ): r""" Args: @@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) x = x + y * e[5] return x @@ -420,9 +421,12 @@ class WanModel(torch.nn.Module): # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -430,12 +434,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) # head x = self.head(x, e) From c14429940f6f9491c77250eb15cad3746e350753 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 12:04:48 -0400 Subject: [PATCH 10/27] Support loading WAN FLF model. 
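
The first-last-frame (FLF) checkpoint carries an extra learned positional embedding
for the CLIP vision tokens under img_emb.emb_pos, so the variant can be detected from
the state dict and its token count passed through to MLPProj, where it is added to the
truncated image embeds before the projection. A rough sketch of the detection,
assuming the key layout used below:

    flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
    if flf_weight is not None:
        # emb_pos is stored as (1, token_number, in_dim)
        dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]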
--- comfy/ldm/wan/model.py | 13 +++++++++++-- comfy/model_detection.py | 3 +++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index d64e73a8e..8907f70ad 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -251,7 +251,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -259,7 +259,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -285,6 +293,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -374,7 +383,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index a4da1afcd..6499bf238 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D From dbcfd092a29c272696bae856d943005fc0cc3036 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 12:42:34 -0400 Subject: [PATCH 11/27] Set default context_img_len to 257 --- comfy/ldm/wan/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 8907f70ad..2a30497c5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -193,7 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, - context_img_len=None, + context_img_len=257, ): r""" Args: From eba7a25e7abf9ec47ab2a42a5c1e6a5cf52351e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 13:18:43 -0400 Subject: [PATCH 12/27] Add WanFirstLastFrameToVideo node to use the new model. 
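
The node builds the same inpaint-style conditioning as WanFunInpaintToVideo: frames
covered by the provided start/end images are set to 0 in the mask (kept as given),
everything else stays 1 (generated), and the two CLIP vision outputs are concatenated
along the token axis when both are supplied. A rough sketch of the mask layout for one
start and one end frame (hypothetical 480x832, 81-frame settings):

    import torch

    length, h, w = 81, 480, 832
    latent_t = ((length - 1) // 4) + 1                    # 21 latent frames
    mask = torch.ones((1, 1, latent_t * 4, h // 8, w // 8))
    mask[:, :, :1 + 3] = 0.0                              # start frame plus its 3 temporal sub-slots
    mask[:, :, -1:] = 0.0                                 # end frame
    # regroup per-frame entries into the 4 temporal sub-slots of each latent frame
    mask = mask.view(1, mask.shape[2] // 4, 4, h // 8, w // 8).transpose(1, 2)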
--- comfy_extras/nodes_wan.py | 98 ++++++++++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2d0f31ac8..8ad358ce8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -4,6 +4,7 @@ import torch import comfy.model_management import comfy.utils import comfy.latent_formats +import comfy.clip_vision class WanImageToVideo: @@ -99,6 +100,72 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class WanFirstLastFrameToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), + "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_start_image is not None: + clip_vision_output = clip_vision_start_image + + if clip_vision_end_image is not None: + if clip_vision_output is not None: + states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2) + clip_vision_output = comfy.clip_vision.Output() + clip_vision_output.penultimate_hidden_states = states + else: + clip_vision_output = clip_vision_end_image + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": 
clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + + class WanFunInpaintToVideo: @classmethod def INPUT_TYPES(s): @@ -122,38 +189,13 @@ class WanFunInpaintToVideo: CATEGORY = "conditioning/video_models" def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - if start_image is not None: - start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - if end_image is not None: - end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + flfv = WanFirstLastFrameToVideo() + return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) - image = torch.ones((length, height, width, 3)) * 0.5 - mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) - - if start_image is not None: - image[:start_image.shape[0]] = start_image - mask[:, :, :start_image.shape[0] + 3] = 0.0 - - if end_image is not None: - image[-end_image.shape[0]:] = end_image - mask[:, :, -end_image.shape[0]:] = 0.0 - - concat_latent_image = vae.encode(image[:, :, :, :3]) - mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - out_latent = {} - out_latent["samples"] = latent - return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, + "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, } From 05d5a75cdcb749286c9ce9e034bb37a2f6195c37 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 18 Apr 2025 02:25:33 +0800 Subject: [PATCH 13/27] Update frontend to 1.16 (Install templates as pip package) (#7623) * install templates as pip package * Update requirements.txt * bump templates version to include hidream --------- Co-authored-by: Chenlei Hu --- app/frontend_management.py | 21 +++++++++++++++++++++ requirements.txt | 3 ++- server.py | 6 ++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index c56ea86e0..7b7923b79 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -184,6 +184,27 @@ comfyui-frontend-package is not installed. ) sys.exit(-1) + @classmethod + def templates_path(cls) -> str: + try: + import comfyui_workflow_templates + + return str( + importlib.resources.files(comfyui_workflow_templates) / "templates" + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. 
+ +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ diff --git a/requirements.txt b/requirements.txt index 851db23bd..278e3eaa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -comfyui-frontend-package==1.15.13 +comfyui-frontend-package==1.16.8 +comfyui-workflow-templates==0.1.1 torch torchsde torchvision diff --git a/server.py b/server.py index 62667ce18..0cc97b248 100644 --- a/server.py +++ b/server.py @@ -736,6 +736,12 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + workflow_templates_path = FrontendManager.templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) From 93292bc450dd291925c45adea00ebedb8a3209ef Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 14:45:01 -0400 Subject: [PATCH 14/27] ComfyUI version 0.3.29 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index a44538d1a..f9161b37e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.28" +__version__ = "0.3.29" diff --git a/pyproject.toml b/pyproject.toml index 6eb1704db..e8fc9555d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.28" +version = "0.3.29" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 19373aee759be2f0868a69603c5d967e5e63e1c5 Mon Sep 17 00:00:00 2001 From: BVH <82035780+bvhari@users.noreply.github.com> Date: Fri, 18 Apr 2025 00:54:33 +0530 Subject: [PATCH 15/27] Add FreSca node (#7631) --- comfy_extras/nodes_fresca.py | 102 +++++++++++++++++++++++++++++++++++ nodes.py | 3 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_fresca.py diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 000000000..b0b86f235 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,102 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. 
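+
+ Example (hypothetical sizes):
+ >>> x = torch.randn(1, 4, 64, 64)
+ >>> y = Fourier_filter(x, scale_low=1.0, scale_high=1.25, freq_cutoff=20)
+ >>> y.shape
+ torch.Size([1, 4, 64, 64])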
+ """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # 2) Create a mask to scale frequencies differently + B, C, H, W = x_freq.shape + crow, ccol = H // 2, W // 2 + + # Initialize mask with high-frequency scaling factor + mask = torch.ones((B, C, H, W), device=device) * scale_high + + # Apply low-frequency scaling factor to center region + mask[ + ..., + crow - freq_cutoff : crow + freq_cutoff, + ccol - freq_cutoff : ccol + freq_cutoff, + ] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/nodes.py b/nodes.py index ae0a2e183..fce3dcb3b 100644 --- a/nodes.py +++ b/nodes.py @@ -2281,7 +2281,8 @@ def init_builtin_extra_nodes(): "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", - "nodes_hidream.py" + "nodes_hidream.py", + "nodes_fresca.py", ] import_failed = [] From 3dc240d08939bef67ed6e7308d6c68d6410bbfa5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 15:46:41 -0400 Subject: [PATCH 16/27] Make fresca work on multi dim. 
--- comfy_extras/nodes_fresca.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index b0b86f235..fa573299a 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -26,19 +26,17 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): x_freq = fft.fftn(x, dim=(-2, -1)) x_freq = fft.fftshift(x_freq, dim=(-2, -1)) - # 2) Create a mask to scale frequencies differently - B, C, H, W = x_freq.shape - crow, ccol = H // 2, W // 2 - # Initialize mask with high-frequency scaling factor - mask = torch.ones((B, C, H, W), device=device) * scale_high + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) # Apply low-frequency scaling factor to center region - mask[ - ..., - crow - freq_cutoff : crow + freq_cutoff, - ccol - freq_cutoff : ccol + freq_cutoff, - ] = scale_low + m[:] = scale_low # 3) Apply frequency-specific scaling x_freq = x_freq * mask From 880c205df1fca4491c78523eb52d1a388f89ef92 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 16:58:27 -0400 Subject: [PATCH 17/27] Add hidream to readme. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a99aca0e7..cf6df7e55 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/) - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) + - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) From 55822faa05293dce6039d504695a12124a3eb35f Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 17 Apr 2025 21:02:24 -0400 Subject: [PATCH 18/27] [Type] Annotate graph.get_input_info (#7386) * [Type] Annotate graph.get_input_info * nit * nit --- comfy_execution/graph.py | 24 +++++++++++++++++++++--- execution.py | 31 ++++++++++++++++--------------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 59b42b746..a2799b52e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,6 +1,9 @@ -import nodes +from __future__ import annotations +from typing import Type, Literal +import nodes from comfy_execution.graph_utils import is_link +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions class DependencyCycleError(Exception): pass @@ -54,7 +57,22 @@ class DynamicPrompt: def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name, valid_inputs=None): +def get_input_info( + class_def: Type[ComfyNodeABC], + input_name: str, + valid_inputs: InputTypeDict | None = None +) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + """Get the input type, category, and extra info for a given input name. + + Arguments: + class_def: The class definition of the node. + input_name: The name of the input to get info for. 
+ valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES(). + + Returns: + tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name. + """ + valid_inputs = valid_inputs or class_def.INPUT_TYPES() input_info = None input_category = None @@ -126,7 +144,7 @@ class TopologicalSort: from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) diff --git a/execution.py b/execution.py index 9a5e27771..d09102f55 100644 --- a/execution.py +++ b/execution.py @@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + _, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated): continue val = inputs[x] - info = (type_input, extra_info) + info = (input_type, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", @@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated): val = val["__value__"] inputs[x] = val - if type_input == "INT": + if input_type == "INT": val = int(val) inputs[x] = val - if type_input == "FLOAT": + if input_type == "FLOAT": val = float(val) inputs[x] = val - if type_input == "STRING": + if input_type == "STRING": val = str(val) inputs[x] = val - if type_input == "BOOLEAN": + if input_type == "BOOLEAN": val = bool(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", + "message": f"Failed to convert an input value to a {input_type} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, @@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if isinstance(type_input, list): - if val not in type_input: + if isinstance(input_type, list): + combo_options = input_type + if val not in combo_options: 
input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" + if len(combo_options) > 20: + list_info = f"(list of length {len(combo_options)})" input_config = None else: - list_info = str(type_input) + list_info = str(combo_options) error = { "type": "value_not_in_list", From 34e06bf7ecb6bca4631b746da5af433098db92c7 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 18 Apr 2025 02:52:18 -0400 Subject: [PATCH 19/27] add support to output camera state (#7582) --- comfy_extras/nodes_load_3d.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index db30030fb..53d892bc4 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -21,8 +21,8 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -41,7 +41,7 @@ class Load3D(): normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, normal_image, lineart_image + return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'] class Load3DAnimation(): @classmethod @@ -59,8 +59,8 @@ class Load3DAnimation(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -77,13 +77,16 @@ class Load3DAnimation(): ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, normal_image + return output_image, output_mask, model_file, normal_image, image['camera_info'] class Preview3D(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -95,13 +98,22 @@ class Preview3D(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } class Preview3DAnimation(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -113,7 +125,13 @@ class Preview3DAnimation(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, 
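The pattern behind the camera-state change above, as a rough standalone sketch (the node names and camera fields here are made up for illustration, not part of the patch): declare a custom socket type such as LOAD3D_CAMERA so one node can emit an opaque camera-state object, and let another node accept it as an optional input and hand it back to the frontend through the "ui" payload.

class CameraInfoSource:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("LOAD3D_CAMERA",)
    RETURN_NAMES = ("camera_info",)
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    def run(self):
        # Opaque state: the backend never inspects it, it is only routed between nodes.
        return ({"position": [0.0, 0.0, 5.0], "target": [0.0, 0.0, 0.0], "fov": 35},)


class CameraInfoPreview:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}, "optional": {"camera_info": ("LOAD3D_CAMERA", {})}}

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    def run(self, **kwargs):
        camera_info = kwargs.get("camera_info", None)
        # Whatever is placed under "ui" is forwarded to the frontend widget.
        return {"ui": {"result": [camera_info]}}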
From 2383a39e3baa11344b0a23b51e3c2f5deff0fc27 Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Fri, 18 Apr 2025 08:53:36 +0200 Subject: [PATCH 20/27] Replace CLIPType if with getattr (#7589) * Replace CLIPType if with getattr * Forgot to remove breakpoint from testing --- nodes.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/nodes.py b/nodes.py index fce3dcb3b..d4082d19d 100644 --- a/nodes.py +++ b/nodes.py @@ -930,26 +930,7 @@ class CLIPLoader: DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" def load_clip(self, clip_name, type="stable_diffusion", device="default"): - if type == "stable_cascade": - clip_type = comfy.sd.CLIPType.STABLE_CASCADE - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "stable_audio": - clip_type = comfy.sd.CLIPType.STABLE_AUDIO - elif type == "mochi": - clip_type = comfy.sd.CLIPType.MOCHI - elif type == "ltxv": - clip_type = comfy.sd.CLIPType.LTXV - elif type == "pixart": - clip_type = comfy.sd.CLIPType.PIXART - elif type == "cosmos": - clip_type = comfy.sd.CLIPType.COSMOS - elif type == "lumina2": - clip_type = comfy.sd.CLIPType.LUMINA2 - elif type == "wan": - clip_type = comfy.sd.CLIPType.WAN - else: - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} if device == "cpu": @@ -977,16 +958,10 @@ class DualCLIPLoader: DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" def load_clip(self, clip_name1, clip_name2, type, device="default"): + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - if type == "sdxl": - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "flux": - clip_type = comfy.sd.CLIPType.FLUX - elif type == "hunyuan_video": - clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO model_options = {} if device == "cpu": From 7ecd5e961465d9bb20fb12b7068e1930da875b0e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 18 Apr 2025 03:16:16 -0400 Subject: [PATCH 21/27] Increase freq_cutoff in FreSca node. 
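Only the widget ceiling changes here; the multi-dim mask code already clamps the cutoff per dimension (f_c = min(freq_cutoff, cc)), so an oversized value simply selects the whole centered band instead of indexing out of range. A quick check of that behaviour, with illustrative values only:

# For a 64-wide frequency axis, a huge cutoff collapses to "cover the whole axis".
H = 64
cc = H // 2
f_c = min(10000, cc)
print(cc - f_c, f_c * 2)  # 0 64 -> narrow(dim, 0, 64) spans the full axis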
--- comfy_extras/nodes_fresca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index fa573299a..ee310c874 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -61,7 +61,7 @@ class FreSca: "tooltip": "Scaling factor for low-frequency components"}), "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, "tooltip": "Scaling factor for high-frequency components"}), - "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, "tooltip": "Number of frequency indices around center to consider as low-frequency"}), } } From f3b09b9f2d374518449d4e0211dbb21b95858eb5 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 18 Apr 2025 15:12:42 -0400 Subject: [PATCH 22/27] [BugFix] Update frontend to 1.16.9 (#7655) Backport https://github.com/Comfy-Org/ComfyUI_frontend/pull/3505 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 278e3eaa8..ff9f66a77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.16.8 +comfyui-frontend-package==1.16.9 comfyui-workflow-templates==0.1.1 torch torchsde From dc300a45698e5cb85f155b8fcb899b1df3c0f855 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Sat, 19 Apr 2025 12:21:46 -0700 Subject: [PATCH 23/27] Add wanfun template workflows. (#7678) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ff9f66a77..5c3a854ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.16.9 -comfyui-workflow-templates==0.1.1 +comfyui-workflow-templates==0.1.3 torch torchsde torchvision From 636d4bfb8994c7f123f15971af5d38a9754377ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 19 Apr 2025 15:55:43 -0400 Subject: [PATCH 24/27] Fix hard crash when the spiece tokenizer path is bad. --- comfy/text_encoders/spiece_tokenizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index 21df4f863..caccb3ca2 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,4 +1,5 @@ import torch +import os class SPieceTokenizer: @staticmethod @@ -15,6 +16,8 @@ class SPieceTokenizer: if isinstance(tokenizer_path, bytes): self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) else: + if not os.path.isfile(tokenizer_path): + raise ValueError("invalid tokenizer") self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) def get_vocab(self): From 4486b0d0ff536b32100863f68e870ba18bf3d051 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Sat, 19 Apr 2025 14:23:31 -0700 Subject: [PATCH 25/27] Update CODEOWNERS and add christian-byrne (#7663) --- CODEOWNERS | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 72a59effe..013ea8622 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,20 @@ # Inlined the team members for now. 
# Maintainers -*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink +*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne # Python web server -/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata +/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne # Node developers -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered -/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered +/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne From f43e1d7f415374cea5bf7561d8e1278e1e52c95a Mon Sep 17 00:00:00 2001 From: power88 <741815398@qq.com> Date: Sun, 20 Apr 2025 07:47:30 +0800 Subject: [PATCH 26/27] Hidream: Allow loading hidream text encoders in CLIPLoader and DualCLIPLoader (#7676) * Hidream: Allow partial loading text encoders * reformat code for ruff check. 
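A sketch of the selection logic the diff below adds to comfy/sd.py, reduced to the part that decides which encoders are present. TEModel, detect_te_model, t5xxl_detect and llama_detect are names taken from the diff; the stand-in enum and function here exist only to keep the example self-contained.

from enum import Enum, auto

class TEModel(Enum):
    CLIP_L = auto()
    CLIP_G = auto()
    T5_XXL = auto()
    LLAMA3_8 = auto()

def hidream_component_flags(detected_models):
    """Map the text-encoder types detected in the loaded state dicts to the
    clip_l / clip_g / t5 / llama flags passed to the HiDream CLIP constructor."""
    return {
        "clip_l": TEModel.CLIP_L in detected_models,
        "clip_g": TEModel.CLIP_G in detected_models,
        "t5": TEModel.T5_XXL in detected_models,
        "llama": TEModel.LLAMA3_8 in detected_models,
    }

# DualCLIPLoader handed a t5xxl file plus a llama checkpoint:
print(hidream_component_flags([TEModel.T5_XXL, TEModel.LLAMA3_8]))
# {'clip_l': False, 'clip_g': False, 't5': True, 'llama': True}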
--- comfy/sd.py | 34 ++++++++++++++++++++++++++++++++++ comfy/text_encoders/hidream.py | 4 ++++ nodes.py | 8 ++++---- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d97873ba2..8aba5d655 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -703,6 +703,7 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + HIDREAM = 14 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -791,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -811,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer @@ -827,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.LLAMA3_8: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: + # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -848,6 +864,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO: clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer + elif clip_type == CLIPType.HIDREAM: + # Detect + hidream_dualclip_classes = [] 
+ for hidream_te in clip_data: + te_model = detect_te_model(hidream_te) + hidream_dualclip_classes.append(te_model) + + clip_l = TEModel.CLIP_L in hidream_dualclip_classes + clip_g = TEModel.CLIP_G in hidream_dualclip_classes + t5 = TEModel.T5_XXL in hidream_dualclip_classes + llama = TEModel.LLAMA3_8 in hidream_dualclip_classes + + # Initialize t5xxl_detect and llama_detect kwargs if needed + t5_kwargs = t5xxl_detect(clip_data) if t5 else {} + llama_kwargs = llama_detect(clip_data) if llama else {} + + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index 6c34c5572..ca54eaa78 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -109,11 +109,15 @@ class HiDreamTEModel(torch.nn.Module): if self.t5xxl is not None: t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) t5_out, t5_pooled = t5_output[:2] + else: + t5_out = None if self.llama is not None: ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) ll_out, ll_pooled = ll_output[:2] ll_out = ll_out[:, 1:] + else: + ll_out = None if t5_out is None: t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) diff --git a/nodes.py b/nodes.py index d4082d19d..b1ab62aad 100644 --- a/nodes.py +++ b/nodes.py @@ -917,7 +917,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -927,7 +927,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -945,7 +945,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -955,7 +955,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and 
llama" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) From fd274944418f1148b762a6e2d37efa820a569071 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 19 Apr 2025 19:49:40 -0400 Subject: [PATCH 27/27] Use empty t5 of size 128 for hidream, seems to give closer results. --- comfy/text_encoders/hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index ca54eaa78..8e1abcfc1 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -120,7 +120,7 @@ class HiDreamTEModel(torch.nn.Module): ll_out = None if t5_out is None: - t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) + t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) if ll_out is None: ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())