diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index 8e1abcfc1..4f0b145f1 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -5,6 +5,7 @@ from comfy import sdxl_clip import comfy.model_management import torch import logging +import folder_paths class HiDreamTokenizer: @@ -16,11 +17,11 @@ class HiDreamTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens - out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): @@ -91,6 +92,8 @@ class HiDreamTEModel(torch.nn.Module): token_weight_pairs_llama = token_weight_pairs["llama"] lg_out = None pooled = None + t5_out = None + ll_out = None extra = {} if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: @@ -104,8 +107,9 @@ class HiDreamTEModel(torch.nn.Module): else: g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) - pooled = torch.cat((l_pooled, g_pooled), dim=-1) - + if self.clip_g is not None and self.clip_l is not None: + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + if self.t5xxl is not None: t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) t5_out, t5_pooled = t5_output[:2] @@ -120,13 +124,15 @@ class HiDreamTEModel(torch.nn.Module): ll_out = None if t5_out is None: - t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) + t5_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "t5_blank.pt") + t5_out = torch.load(t5_path, map_location=comfy.model_management.intermediate_device()) if ll_out is None: ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) if pooled is None: - pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) + pooled_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "pooled_blank.pt") + pooled = torch.load(pooled_path, map_location=comfy.model_management.intermediate_device()) extra["conditioning_llama3"] = ll_out return t5_out, pooled, extra diff --git a/folder_paths.py b/folder_paths.py index 9a525e5a1..3f6bc03cd 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -4,7 +4,7 @@ import os import time import mimetypes import logging -from typing import Literal +from typing import Literal, List from collections.abc import Collection from comfy.cli_args import args @@ -45,6 +45,7 @@ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetwo folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["hidream_empty_latents"] = ([os.path.join(models_dir, "hidream_empty_latents")], supported_pt_extensions) output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") @@ -141,7 +142,7 @@ def get_directory_by_type(type_name: str) -> str | None: return get_input_directory() return None -def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]: +def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory())