missed adding these in previous commit

envy-ai 2025-05-13 13:11:53 -04:00
parent 90f23bac28
commit 27d11db345
2 changed files with 17 additions and 10 deletions

View File

@@ -5,6 +5,7 @@ from comfy import sdxl_clip
 import comfy.model_management
 import torch
 import logging
+import folder_paths


 class HiDreamTokenizer:
@@ -16,11 +17,11 @@ class HiDreamTokenizer:

     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
-        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+        t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
-        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
+        out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

     def untokenize(self, token_weight_pair):
@@ -91,6 +92,8 @@ class HiDreamTEModel(torch.nn.Module):
         token_weight_pairs_llama = token_weight_pairs["llama"]
         lg_out = None
         pooled = None
+        t5_out = None
+        ll_out = None
         extra = {}

         if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
@@ -104,7 +107,8 @@ class HiDreamTEModel(torch.nn.Module):
             else:
                 g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

-            pooled = torch.cat((l_pooled, g_pooled), dim=-1)
+            if self.clip_g is not None and self.clip_l is not None:
+                pooled = torch.cat((l_pooled, g_pooled), dim=-1)

         if self.t5xxl is not None:
             t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
@@ -120,13 +124,15 @@ class HiDreamTEModel(torch.nn.Module):
             ll_out = None

         if t5_out is None:
-            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
+            t5_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "t5_blank.pt")
+            t5_out = torch.load(t5_path, map_location=comfy.model_management.intermediate_device())

         if ll_out is None:
             ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())

         if pooled is None:
-            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
+            pooled_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "pooled_blank.pt")
+            pooled = torch.load(pooled_path, map_location=comfy.model_management.intermediate_device())

         extra["conditioning_llama3"] = ll_out
         return t5_out, pooled, extra
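Note: this commit assumes pre-generated blank-conditioning files (t5_blank.pt, pooled_blank.pt) already exist under models/hidream_empty_latents; the diff does not show how they are produced. A minimal, hypothetical sketch of generating placeholder files with the right shapes, taken from the torch.zeros calls removed above — in practice the files would presumably hold the encoders' real empty-prompt outputs rather than zeros:

    import os
    import torch

    # Hypothetical: write blank embeddings with the shapes the old code zero-filled.
    # t5_blank.pt: (1, 128, 4096); pooled_blank.pt: (1, 768 + 1280).
    out_dir = os.path.join("models", "hidream_empty_latents")
    os.makedirs(out_dir, exist_ok=True)
    torch.save(torch.zeros((1, 128, 4096)), os.path.join(out_dir, "t5_blank.pt"))
    torch.save(torch.zeros((1, 768 + 1280)), os.path.join(out_dir, "pooled_blank.pt"))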

View File

@@ -4,7 +4,7 @@ import os
 import time
 import mimetypes
 import logging
-from typing import Literal
+from typing import Literal, List
 from collections.abc import Collection

 from comfy.cli_args import args
@ -45,6 +45,7 @@ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetwo
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["hidream_empty_latents"] = ([os.path.join(models_dir, "hidream_empty_latents")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
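With the new entry registered, files placed in models/hidream_empty_latents resolve through the same folder lookup the first file uses to load them. A minimal usage sketch:

    import folder_paths

    # Resolves to <models_dir>/hidream_empty_latents/t5_blank.pt and raises
    # if the file is absent from every registered search path.
    t5_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "t5_blank.pt")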
@@ -141,7 +142,7 @@ def get_directory_by_type(type_name: str) -> str | None:
         return get_input_directory()
     return None

-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]:
     """
     Example:
         files = os.listdir(folder_paths.get_input_directory())
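The widened annotation lets callers pass several content types in one call. A short usage sketch consistent with the docstring example above:

    import os
    import folder_paths

    files = os.listdir(folder_paths.get_input_directory())
    # List[Literal[...]] now permits filtering on multiple categories at once.
    media = folder_paths.filter_files_content_types(files, ["image", "video"])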