mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 19:13:02 +08:00
missed adding these in previous commit
This commit is contained in:
parent
90f23bac28
commit
27d11db345
@ -5,6 +5,7 @@ from comfy import sdxl_clip
|
||||
import comfy.model_management
|
||||
import torch
|
||||
import logging
|
||||
import folder_paths
|
||||
|
||||
|
||||
class HiDreamTokenizer:
|
||||
@ -16,11 +17,11 @@ class HiDreamTokenizer:
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
|
||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids)
|
||||
out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
@ -91,6 +92,8 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||
lg_out = None
|
||||
pooled = None
|
||||
t5_out = None
|
||||
ll_out = None
|
||||
extra = {}
|
||||
|
||||
if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
|
||||
@ -104,7 +107,8 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
else:
|
||||
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
|
||||
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
if self.clip_g is not None and self.clip_l is not None:
|
||||
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
if self.t5xxl is not None:
|
||||
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
|
||||
@ -120,13 +124,15 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
ll_out = None
|
||||
|
||||
if t5_out is None:
|
||||
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
|
||||
t5_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "t5_blank.pt")
|
||||
t5_out = torch.load(t5_path, map_location=comfy.model_management.intermediate_device())
|
||||
|
||||
if ll_out is None:
|
||||
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
|
||||
|
||||
if pooled is None:
|
||||
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
|
||||
pooled_path = folder_paths.get_full_path_or_raise("hidream_empty_latents", "pooled_blank.pt")
|
||||
pooled = torch.load(pooled_path, map_location=comfy.model_management.intermediate_device())
|
||||
|
||||
extra["conditioning_llama3"] = ll_out
|
||||
return t5_out, pooled, extra
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import time
|
||||
import mimetypes
|
||||
import logging
|
||||
from typing import Literal
|
||||
from typing import Literal, List
|
||||
from collections.abc import Collection
|
||||
|
||||
from comfy.cli_args import args
|
||||
@ -45,6 +45,7 @@ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetwo
|
||||
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
|
||||
folder_names_and_paths["hidream_empty_latents"] = ([os.path.join(models_dir, "hidream_empty_latents")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
@ -141,7 +142,7 @@ def get_directory_by_type(type_name: str) -> str | None:
|
||||
return get_input_directory()
|
||||
return None
|
||||
|
||||
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
|
||||
def filter_files_content_types(files: list[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> list[str]:
|
||||
"""
|
||||
Example:
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user