LLM support in ComfyUI

- Currently uses `transformers`
- Supports model management: models are loaded and unloaded correctly
  based on what your machine can support
- Includes a TextDiffuser 2 workflow to demonstrate text rendering in
  SD1.5
doctorpangloss 2024-05-14 17:30:23 -07:00
parent 0ee2f3bf15
commit 8741cb3ce8
20 changed files with 893 additions and 213 deletions

View File

@@ -87,9 +87,10 @@ class CLIPTextModel_(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
+        vocab_size = config_dict["vocab_size"]
         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, dtype=torch.float32, device=device)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

View File

@@ -1,14 +1,18 @@
+from __future__ import annotations
 import os

 from . import sd, utils

-def first_file(path, filenames):
+def first_file(path, filenames) -> str | None:
     for f in filenames:
         p = os.path.join(path, f)
         if os.path.exists(p):
-            return p
+            return str(p)
     return None

 def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
     diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
     unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
@@ -22,15 +26,17 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     if text_encoder2_path is not None:
         text_encoder_paths.append(text_encoder2_path)

-    unet = sd.load_unet(unet_path)
+    if unet_path is not None:
+        unet = sd.load_unet(unet_path)

     clip = None
-    if output_clip:
-        clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+    textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
+    if output_clip and not all(te is None for te in text_encoder_paths):
+        clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config1)

     vae = None
-    if output_vae:
+    if output_vae and vae_path is not None:
         _sd = utils.load_torch_file(vae_path)
         vae = sd.VAE(sd=_sd)
-    return (unet, clip, vae)
+    return unet, clip, vae
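For orientation, a minimal usage sketch of the updated loader. The import path and the checkpoint directory are assumptions, not part of this diff:

```python
# Sketch only: load a local diffusers-format checkpoint folder.
# "comfy.diffusers_load" and the directory layout are assumed here.
from comfy import diffusers_load

unet, clip, vae = diffusers_load.load_diffusers(
    "/models/diffusers/stable-diffusion-v1-5",  # folder containing unet/, text_encoder/, vae/
    output_vae=True,
    output_clip=True,
)
# With this change, clip and vae come back as None when their files are absent
# instead of failing outright.
```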

View File

@@ -0,0 +1,5 @@
+from fastchat.model.model_adapter import register_model_adapter
+
+from .fastchat_adapters import Phi3Adapter
+
+register_model_adapter(Phi3Adapter)

View File

@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from fastchat.conversation import Conversation, get_conv_template
+from fastchat.model.model_adapter import BaseModelAdapter
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class Phi3Adapter(BaseModelAdapter):
+    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""
+
+    def match(self, model_path: str):
+        return "phi-3-mini-128k-instruct" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        self.model = model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
+        return model, tokenizer
+
+    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
+        if input:
+            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
+        else:
+            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
+        return prompt
+
+    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
+        prompt = self.generate_prompt(messages[-1]["content"])
+        for i in range(len(messages) - 2, -1, -1):
+            if messages[i]["role"] == "user":
+                prompt = self.generate_prompt(messages[i]["content"]) + prompt
+            elif messages[i]["role"] == "assistant":
+                prompt = messages[i]["content"] + prompt
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
+        generation_kwargs = {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "do_sample": do_sample,
+            "pad_token_id": self.tokenizer.eos_token_id,
+        }
+        output_ids = self.model.generate(
+            input_ids,
+            **generation_kwargs
+        )
+        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        output = output.replace(prompt, "").strip()
+        return output
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("phi-3-mini")
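A hedged usage sketch of the adapter above. The checkpoint path is a placeholder, and enough memory to hold Phi-3-mini on the host is assumed:

```python
# Sketch only: exercises the adapter API defined above.
adapter = Phi3Adapter()
model, tokenizer = adapter.load_model(
    "microsoft/Phi-3-mini-128k-instruct",           # any path that adapter.match() accepts
    from_pretrained_kwargs={"torch_dtype": "auto"},
)
messages = [{"role": "user", "content": "Name three uses of a text encoder."}]
print(adapter.generate_response(messages, max_new_tokens=128))
```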

View File

@@ -0,0 +1,8 @@
+from __future__ import annotations
+
+from typing import NamedTuple, Dict, Any
+
+
+class ProcArgsRes(NamedTuple):
+    seed: int
+    generate_kwargs: Dict[str, Any]

View File

@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import warnings
+from typing import Optional, Any
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from ..model_management import unet_offload_device, get_torch_device
+from ..model_management_types import ModelManageable
+
+
+class TransformersManagedModel(ModelManageable):
+    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
+        self.repo_id = repo_id
+        self.model = model
+        self.tokenizer = tokenizer
+        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
+        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
+        self.load_device = get_torch_device()
+        self.offload_device = unet_offload_device()
+        if model.device != self.offload_device:
+            model.to(device=self.offload_device)
+
+    load_device: torch.device
+    offload_device: torch.device
+    model: PreTrainedModel
+
+    @property
+    def current_device(self) -> torch.device:
+        return self.model.device
+
+    def is_clone(self, other: Any) -> bool:
+        return hasattr(other, "model") and self.model is other.model
+
+    def clone_has_same_weights(self, clone: Any) -> bool:
+        if not isinstance(clone, TransformersManagedModel):
+            return False
+        clone: TransformersManagedModel
+        if not self.is_clone(clone):
+            return False
+        return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
+
+    def model_size(self) -> int:
+        return self._size
+
+    def model_patches_to(self, arg: torch.device | torch.dtype):
+        if isinstance(arg, torch.device):
+            self.model.to(device=arg)
+        else:
+            self.model.to(arg)
+
+    def model_dtype(self) -> torch.dtype:
+        return self.model.dtype
+
+    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
+        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+        return self.model.to(device=device_to)
+
+    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
+        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+        return self.model.to(device=device_to)
+
+    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        warnings.warn("Transformers models do not currently support adapters like LoRAs")
+        return self.model.to(device=offload_device)
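A sketch of how a transformers model might be handed to the existing memory-management layer through this wrapper. The module paths are assumptions based on the imports in this commit, and the repo id is only an example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module paths assumed from this commit's package layout (hypothetical).
from comfy import model_management
from comfy.language.transformers_model_management import TransformersManagedModel

repo_id = "microsoft/Phi-3-mini-4k-instruct"
hf_model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(repo_id)

managed = TransformersManagedModel(repo_id, hf_model, tokenizer)
# The existing loader decides whether the model fits on the GPU, frees other
# models first if needed, and otherwise keeps it on the offload (CPU) device.
model_management.load_models_gpu([managed])
```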

View File

@@ -7,7 +7,7 @@ from os.path import join
 from typing import List, Any, Optional, Union

 import tqdm
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, scan_cache_dir
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -167,6 +167,7 @@ KNOWN_CHECKPOINTS = [
     CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
     CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
     CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
 ]

 KNOWN_UNCLIP_CHECKPOINTS = [
@@ -297,6 +298,12 @@ KNOWN_VAES = [
     HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
 ]

+KNOWN_HUGGINGFACE_MODEL_REPOS = {
+    "JingyeChen22/textdiffuser2_layout_planner",
+    'JingyeChen22/textdiffuser2-full-ft',
+    "microsoft/Phi-3-mini-4k-instruct",
+}
+
 def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
     if args.disable_known_models:
@@ -304,3 +311,10 @@ def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile
     symbol += models
     folder_paths.invalidate_cache(folder_name)
     return symbol
+
+def huggingface_repos() -> List[str]:
+    cache_info = scan_cache_dir()
+    existing_repo_ids = frozenset(cache_item.repo_id for cache_item in cache_info.repos if cache_item.repo_type == "model")
+    known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
+    return list(existing_repo_ids | known_repo_ids)
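For illustration (the import path is an assumption based on this diff), the helper simply unions the hard-coded defaults with whatever already sits in the local Hugging Face cache:

```python
# Module path assumed from the imports elsewhere in this commit.
from comfy.model_downloader import huggingface_repos

repos = huggingface_repos()
# The known defaults are always present, plus any locally cached model repos.
assert "microsoft/Phi-3-mini-4k-instruct" in repos
```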

View File

@@ -1,34 +1,36 @@
 from __future__ import annotations
+import logging
+import sys
+from enum import Enum
+from threading import RLock
 from typing import Literal

 import psutil
-import logging
-from enum import Enum
-from .cli_args import args
-from . import interruption
-from threading import RLock
 import torch
-import sys
+
+from . import interruption
+from .cli_args import args
 from .model_management_types import ModelManageable

 model_management_lock = RLock()

 class VRAMState(Enum):
-    DISABLED = 0 #No vram present: no need to move models to vram
-    NO_VRAM = 1 #Very low vram: enable all the options to save vram
+    DISABLED = 0  # No vram present: no need to move models to vram
+    NO_VRAM = 1  # Very low vram: enable all the options to save vram
     LOW_VRAM = 2
     NORMAL_VRAM = 3
     HIGH_VRAM = 4
-    SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
+    SHARED = 5  # No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.

 class CPUState(Enum):
     GPU = 0
     CPU = 1
     MPS = 2

 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM
@@ -46,6 +48,7 @@ if args.deterministic:
 directml_enabled = False
 if args.directml is not None:
     import torch_directml
     directml_enabled = True
     device_index = args.directml
     if device_index < 0:
@@ -54,10 +57,11 @@ if args.directml is not None:
         directml_device = torch_directml.device(device_index)
     logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
-    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
+    lowvram_available = False  # TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
         xpu_available = True
 except:
@@ -73,6 +77,7 @@ except:
 if args.cpu:
     cpu_state = CPUState.CPU

 def is_intel_xpu():
     global cpu_state
     global xpu_available
@@ -81,6 +86,7 @@ def is_intel_xpu():
             return True
     return False

 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -97,6 +103,7 @@ def get_torch_device():
         else:
             return torch.device(torch.cuda.current_device())

 def get_total_memory(dev=None, torch_total_too=False):
     global directml_enabled
     if dev is None:
@@ -107,7 +114,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total_torch = mem_total
     else:
         if directml_enabled:
-            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total = 1024 * 1024 * 1024  # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
             stats = torch.xpu.memory_stats(dev)
@@ -126,6 +133,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     else:
         return mem_total

 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@@ -147,6 +155,7 @@ else:
 try:
     import xformers
     import xformers.ops
     XFORMERS_IS_AVAILABLE = True
     try:
         XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
@@ -164,6 +173,7 @@ else:
 except:
     XFORMERS_IS_AVAILABLE = False

 def is_nvidia():
     global cpu_state
     if cpu_state == CPUState.GPU:
@@ -171,6 +181,7 @@ def is_nvidia():
             return True
     return False

 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
@@ -205,7 +216,6 @@ elif args.bf16_vae:
 elif args.fp32_vae:
     VAE_DTYPE = torch.float32

 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -233,7 +243,6 @@ if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to

 if cpu_state != CPUState.GPU:
     vram_state = VRAMState.DISABLED
@@ -247,6 +256,7 @@ DISABLE_SMART_MEMORY = args.disable_smart_memory
 if DISABLE_SMART_MEMORY:
     logging.info("Disabling smart memory management")

 def get_torch_device_name(device):
     if hasattr(device, 'type'):
         if device.type == "cuda":
@@ -262,6 +272,7 @@ def get_torch_device_name(device):
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

 try:
     logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
@@ -271,6 +282,7 @@ logging.info("VAE dtype: {}".format(VAE_DTYPE))
 current_loaded_models = []

 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -279,6 +291,7 @@ def module_size(module):
         module_mem += t.nelement() * t.element_size()
     return module_mem

 class LoadedModel:
     def __init__(self, model: ModelManageable):
         self.model = model
@@ -328,9 +341,11 @@ class LoadedModel:
     def __eq__(self, other):
         return self.model is other.model

 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

 def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
     with model_management_lock:
         to_unload = []
@@ -361,12 +376,13 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
         return unload_weight

 def free_memory(memory_required, device, keep_loaded=[]):
     with model_management_lock:
         unloaded_model = []
         can_unload = []

-        for i in range(len(current_loaded_models) -1, -1, -1):
+        for i in range(len(current_loaded_models) - 1, -1, -1):
             shift_model = current_loaded_models[i]
             if shift_model.device == device:
                 if shift_model not in keep_loaded:
@@ -391,6 +407,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()

 def load_models_gpu(models, memory_required=0):
     global vram_state
@@ -424,7 +441,7 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):  # unload clones where the weights are different
             total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
@@ -432,7 +449,7 @@ def load_models_gpu(models, memory_required=0):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)

     for loaded_model in models_to_load:
-        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
+        weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False)  # unload the rest of the clones where the weights can stay loaded
         if weights_unloaded is not None:
             loaded_model.weights_loaded = not weights_unloaded
@@ -447,8 +464,8 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
-            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
+            if model_size > (current_free_mem - inference_memory):  # only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
         else:
             lowvram_model_memory = 0
@@ -465,6 +482,7 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])

 def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
@@ -472,8 +490,8 @@ def cleanup_models(keep_clone_weights_loaded=False):
             if sys.getrefcount(current_loaded_models[i].model) <= 2:
                 if not keep_clone_weights_loaded:
                     to_delete = [i] + to_delete
-                #TODO: find a less fragile way to do this.
-                elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                # TODO: find a less fragile way to do this.
+                elif sys.getrefcount(current_loaded_models[i].real_model) <= 3:  # references from .real_model + the .model
                     to_delete = [i] + to_delete

         for i in to_delete:
@@ -481,6 +499,7 @@ def cleanup_models(keep_clone_weights_loaded=False):
             x.model_unload()
             del x

 def dtype_size(dtype):
     dtype_size = 4
     if dtype == torch.float16 or dtype == torch.bfloat16:
@@ -490,17 +509,19 @@ def dtype_size(dtype):
     else:
         try:
             dtype_size = dtype.itemsize
-        except: #Old pytorch doesn't have .itemsize
+        except:  # Old pytorch doesn't have .itemsize
             pass
     return dtype_size

 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
     else:
         return torch.device("cpu")

-def unet_inital_load_device(parameters, dtype):
+def unet_initial_load_device(parameters, dtype):
     torch_dev = get_torch_device()
     if vram_state == VRAMState.HIGH_VRAM:
         return torch_dev
@@ -518,7 +539,8 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev

-def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
@@ -535,8 +557,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.bfloat16
     return torch.float32

 # None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
     if weight_dtype == torch.float32:
         return None
@@ -556,12 +579,14 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     else:
         return torch.float32

 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")

 def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
@@ -573,6 +598,7 @@ def text_encoder_device():
     else:
         return torch.device("cpu")

 def text_encoder_dtype(device=None):
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
@@ -595,27 +621,32 @@ def intermediate_device():
     else:
         return torch.device("cpu")

 def vae_device():
     if args.cpu_vae:
         return torch.device("cpu")
     return get_torch_device()

 def vae_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")

 def vae_dtype():
     global VAE_DTYPE
     return VAE_DTYPE

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
     return "cuda"

-def supports_dtype(device, dtype): #TODO
+def supports_dtype(device, dtype):  # TODO
     if dtype == torch.float32:
         return True
     if is_device_cpu(device):
@@ -626,12 +657,14 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False

 def device_supports_non_blocking(device):
     if is_device_mps(device):
-        return False #pytorch bug? mps doesn't support non blocking
+        return False  # pytorch bug? mps doesn't support non blocking
     return False
     # return True #TODO: figure out why this causes issues

 def cast_to_device(tensor, device, dtype, copy=False):
     with model_management_lock:
         device_supports_cast = False
@@ -655,6 +688,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
             else:
                 return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)

 def xformers_enabled():
     global directml_enabled
     global cpu_state
@@ -674,18 +708,21 @@ def xformers_enabled_vae():
     return XFORMERS_ENABLED_VAE

 def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION

 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
-        #TODO: more reliable way of checking for flash attention?
-        if is_nvidia(): #pytorch flash attention only works on Nvidia
+        # TODO: more reliable way of checking for flash attention?
+        if is_nvidia():  # pytorch flash attention only works on Nvidia
             return True
     return False

 def get_free_memory(dev=None, torch_free_too=False):
     global directml_enabled
     if dev is None:
@@ -696,7 +733,7 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_torch = mem_free_total
     else:
         if directml_enabled:
-            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_total = 1024 * 1024 * 1024  # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
             stats = torch.xpu.memory_stats(dev)
@@ -718,29 +755,36 @@ def get_free_memory(dev=None, torch_free_too=False):
     else:
         return mem_free_total

 def cpu_mode():
     global cpu_state
     return cpu_state == CPUState.CPU

 def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS

 def is_device_type(device, type):
     if hasattr(device, 'type'):
         if (device.type == type):
             return True
     return False

 def is_device_cpu(device):
     return is_device_type(device, 'cpu')

 def is_device_mps(device):
     return is_device_type(device, 'mps')

 def is_device_cuda(device):
     return is_device_type(device, 'cuda')

 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
@@ -781,9 +825,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
         return False

     fp16_works = False
-    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    #when the model doesn't actually fit on the card
-    #TODO: actually test if GP106 and others have the same type of behavior
+    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
+    # when the model doesn't actually fit on the card
+    # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
@@ -797,7 +841,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 7:
         return False

-    #FP16 is just broken on these cards
+    # FP16 is just broken on these cards
     nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
     for x in nvidia_16_series:
         if x in props.name:
@@ -805,12 +849,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     return True

 def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if device is not None:
-        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
+        if is_device_cpu(device):  # TODO ? bf16 works on CPU but is extremely slow
             return False

-    if device is not None: #TODO not sure about mps bf16 support
+    if device is not None:  # TODO not sure about mps bf16 support
         if is_device_mps(device):
             return False
@@ -842,6 +887,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False

 def soft_empty_cache(force=False):
     with model_management_lock:
         global cpu_state
@@ -850,16 +896,17 @@ def soft_empty_cache(force=False):
         elif is_intel_xpu():
             torch.xpu.empty_cache()
         elif torch.cuda.is_available():
-            if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
+            if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
                 torch.cuda.empty_cache()
                 torch.cuda.ipc_collect()

 def unload_all_models():
     with model_management_lock:
         free_memory(1e30, get_torch_device())

-def resolve_lowvram_weight(weight, model, key): #TODO: remove
+def resolve_lowvram_weight(weight, model, key):  # TODO: remove
     return weight

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Protocol, Optional
+from typing import Protocol, Optional, Any

 import torch
@@ -18,13 +18,12 @@ class ModelManageable(Protocol):
     load_device: torch.device
     offload_device: torch.device
     model: torch.nn.Module
-    current_device: torch.device

     @property
-    def dtype(self) -> torch.dtype:
+    def current_device(self) -> torch.device:
         ...

-    def is_clone(self, other: torch.nn.Module) -> bool:
+    def is_clone(self, other: Any) -> bool:
         ...

     def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:

View File

@@ -1,11 +1,12 @@
-import torch
 import copy
 import inspect
 import logging
 import uuid

-from . import utils
+import torch
+
 from . import model_management
+from . import utils
 from .model_management_types import ModelManageable
@@ -20,6 +21,7 @@ def apply_weight_decompose(dora_scale, weight):
     return weight * (dora_scale / weight_norm)

 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
@@ -41,6 +43,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
     model_options["transformer_options"] = to
     return model_options

 class ModelPatcher(ModelManageable):
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
@@ -49,14 +52,15 @@ class ModelPatcher(ModelManageable):
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
-        self.model_options = {"transformer_options":{}}
+        self.model_options = {"transformer_options": {}}
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
+        self._current_device: torch.device
         if current_device is None:
-            self.current_device = self.offload_device
+            self._current_device = self.offload_device
         else:
-            self.current_device = current_device
+            self._current_device = current_device

         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
@@ -71,7 +75,7 @@ class ModelPatcher(ModelManageable):
         return self.size

     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -107,7 +111,7 @@ class ModelPatcher(ModelManageable):
     def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
-            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
+            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
         if disable_cfg1_optimization:
@@ -270,18 +274,20 @@ class ModelPatcher(ModelManageable):
         if device_to is not None:
             self.model.to(device_to)
-            self.current_device = device_to
+            self._current_device = device_to

         return self.model

     def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
         self.patch_model(device_to, patch_weights=False)

-        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+        logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))

         class LowVramPatch:
             def __init__(self, key, model_patcher):
                 self.key = key
                 self.model_patcher = model_patcher

             def __call__(self, weight):
                 return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
@@ -325,7 +331,7 @@ class ModelPatcher(ModelManageable):
                 weight *= strength_model

             if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key), )
+                v = (self.calculate_weight(v[1:], v[0].clone(), key),)

             if len(v) == 1:
                 patch_type = "diff"
@@ -340,14 +346,14 @@ class ModelPatcher(ModelManageable):
                     logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                 else:
                     weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
-            elif patch_type == "lora": #lora/locon
+            elif patch_type == "lora":  # lora/locon
                 mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
                 dora_scale = v[4]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
+                    # locon mid weights, hopefully the math is fine because I didn't properly test it
                     mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
@@ -407,7 +413,7 @@ class ModelPatcher(ModelManageable):
                 w2a = v[3]
                 w2b = v[4]
                 dora_scale = v[7]
-                if v[5] is not None: #cp decomposition
+                if v[5] is not None:  # cp decomposition
                     t1 = v[5]
                     t2 = v[6]
                     m1 = torch.einsum('i j k l, j r, i p -> p r k l',
@@ -478,10 +484,14 @@ class ModelPatcher(ModelManageable):
         if device_to is not None:
             self.model.to(device_to)
-            self.current_device = device_to
+            self._current_device = value = device_to

         keys = list(self.object_patches_backup.keys())
         for k in keys:
             utils.set_attr(self.model, k, self.object_patches_backup[k])

         self.object_patches_backup.clear()
+
+    @property
+    def current_device(self) -> torch.device:
+        return self._current_device

View File

@@ -9,6 +9,7 @@ import logging
 from PIL import Image, ImageOps, ImageSequence, ImageFile
 from PIL.PngImagePlugin import PngInfo
+from huggingface_hub import hf_hub_download, snapshot_download
 from natsort import natsorted
 import numpy as np
 import safetensors.torch
@@ -25,11 +26,13 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..execution_context import current_execution_context
 from ..images import open_image
-from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
+from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
 from ..open_exr import load_exr
 from .. import node_helpers
+from ..utils import comfy_tqdm

 class CLIPTextEncode:
     @classmethod
@@ -513,11 +516,14 @@ class DiffusersLoader:
                     if "model_index.json" in files:
                         paths.append(os.path.relpath(root, start=search_path))

+        paths += huggingface_repos()
+        paths = list(frozenset(paths))
         return {"required": {"model_path": (paths,), }}

     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

-    CATEGORY = "advanced/loaders/deprecated"
+    CATEGORY = "advanced/loaders"

     def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
         for search_path in folder_paths.get_folder_paths("diffusers"):
@@ -526,6 +532,9 @@ class DiffusersLoader:
             if os.path.exists(path):
                 model_path = path
                 break
+        if not os.path.exists(model_path):
+            with comfy_tqdm():
+                model_path = snapshot_download(model_path)

         return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
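A sketch of the resolution step added above: when the selected `model_path` is a Hugging Face repo id rather than a local folder, it is fetched into the cache first. The repo id below is one of the known defaults from this commit:

```python
import os
from huggingface_hub import snapshot_download

model_path = "JingyeChen22/textdiffuser2-full-ft"
if not os.path.exists(model_path):
    # Downloads the repo (or reuses the cache) and returns the local snapshot directory,
    # which load_diffusers can then read like any other diffusers checkpoint folder.
    model_path = snapshot_download(model_path)
```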

View File

@@ -1,31 +1,32 @@
-import torch
-from enum import Enum
+from __future__ import annotations
+
+import dataclasses
 import logging
+from enum import Enum
+from typing import Any, Optional

-from . import model_management
-from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
-from .ldm.cascade.stage_a import StageA
-from .ldm.cascade.stage_c_coder import StageC_coder
+import torch
 import yaml

-from . import utils
 from . import clip_vision
-from . import gligen
 from . import diffusers_convert
+from . import gligen
+from . import lora
 from . import model_detection
+from . import model_management
+from . import model_patcher
+from . import model_sampling
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
-from . import model_patcher
-from . import model_sampling
-from . import lora
+from . import utils
+from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
+from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .t2i_adapter import adapter
 from .taesd import taesd

 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
     m = set(m)
@@ -40,6 +41,7 @@ def load_model_weights(model, sd):
         logging.warning("missing {}".format(m))
     return model

 def load_clip_weights(model, sd):
     k = list(sd.keys())
     for x in k:
@@ -87,7 +89,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):

 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False):
+    def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None):
         if no_init:
             return
         params = target.params.copy()
@@ -98,10 +100,12 @@ class CLIP:
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
         params['dtype'] = model_management.text_encoder_dtype(load_device)
+        if "textmodel_json_config" not in params and textmodel_json_config is not None:
+            params['textmodel_json_config'] = textmodel_json_config

         self.cond_stage_model = clip(**(params))

-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
         self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
@@ -157,12 +161,13 @@ class CLIP:
     def get_key_patches(self):
         return self.patcher.get_key_patches()

 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None):
-        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
+        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys():  # diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)

-        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
+        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype)  # These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
         self.upscale_ratio = 8
@@ -181,16 +186,16 @@ class VAE:
                                                             decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
         elif "taesd_decoder.1.weight" in sd:
             self.first_stage_model = taesd.TAESD()
-        elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
+        elif "vquantizer.codebook.weight" in sd:  # VQGan: stage a of stable cascade
             self.first_stage_model = StageA()
             self.downscale_ratio = 4
             self.upscale_ratio = 4
-            #TODO
-            #self.memory_used_encode
-            #self.memory_used_decode
+            # TODO
+            # self.memory_used_encode
+            # self.memory_used_decode
             self.process_input = lambda image: image
             self.process_output = lambda image: image
-        elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+        elif "backbone.1.0.block.0.1.num_batches_tracked" in sd:  # effnet: encoder for stage c latent of stable cascade
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
@@ -198,22 +203,22 @@ class VAE:
             for k in sd:
                 new_sd["encoder.{}".format(k)] = sd[k]
             sd = new_sd
-        elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+        elif "blocks.11.num_batches_tracked" in sd:  # previewer: decoder for stage c latent of stable cascade
             self.first_stage_model = StageC_coder()
             self.latent_channels = 16
             new_sd = {}
             for k in sd:
                 new_sd["previewer.{}".format(k)] = sd[k]
             sd = new_sd
-        elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+        elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd:  # combined effnet and previewer for stable cascade
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
         elif "decoder.conv_in.weight" in sd:
-            #default SD1.x/SD2.x VAE parameters
+            # default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-            if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+            if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd:  # Stable diffusion x4 upscaler VAE
                 ddconfig['ch_mult'] = [1, 2, 4]
                 self.downscale_ratio = 4
                 self.upscale_ratio = 4
@@ -261,7 +266,7 @@ class VAE:
             pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
         return pixels

-    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
         steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
@@ -269,22 +274,22 @@ class VAE:
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
-             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
-             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
+             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
+             utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
             / 3.0)
         return output

-    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
         steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples
@@ -298,23 +303,23 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
+                pixel_samples[x:x + batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
         except model_management.OOM_EXCEPTION as e:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
return pixel_samples return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap) output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1) return output.movedim(1, -1)
def encode(self, pixel_samples): def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1, 1)
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used) model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@ -323,8 +328,8 @@ class VAE:
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -332,16 +337,17 @@ class VAE:
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1, 1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
return samples return samples
def get_sd(self): def get_sd(self):
return self.first_stage_model.state_dict() return self.first_stage_model.state_dict()
class StyleModel: class StyleModel:
def __init__(self, model, device="cpu"): def __init__(self, model, device="cpu"):
self.model = model self.model = model
@ -360,26 +366,33 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data) model.load_state_dict(model_data)
return StyleModel(model) return StyleModel(model)
class CLIPType(Enum): class CLIPType(Enum):
STABLE_DIFFUSION = 1 STABLE_DIFFUSION = 1
STABLE_CASCADE = 2 STABLE_CASCADE = 2
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
@dataclasses.dataclass
class CLIPTarget:
clip: Optional[Any] = None
vae: Optional[Any] = None
params: Optional[dict] = dataclasses.field(default_factory=dict)
tokenizer: Optional[Any] = None
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(utils.load_torch_file(p, safe_load=True)) clip_data.append(utils.load_torch_file(p, safe_load=True))
class EmptyClass:
pass
for i in range(len(clip_data)): for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]: if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = utils.clip_text_transformers_convert(clip_data[i], "", "") clip_data[i] = utils.clip_text_transformers_convert(clip_data[i], "", "")
else: else:
if "text_projection" in clip_data[i]: if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) # old models saved with the CLIPSave node
clip_target = EmptyClass() clip_target = CLIPTarget()
clip_target.params = {} clip_target.params = {}
if len(clip_data) == 1: if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
@ -399,7 +412,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory) clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
for c in clip_data: for c in clip_data:
m, u = clip.load_sd(c) m, u = clip.load_sd(c)
if len(m) > 0: if len(m) > 0:
@ -409,6 +422,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
logging.debug("clip unexpected: {}".format(u)) logging.debug("clip unexpected: {}".format(u))
return clip return clip
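As a rough usage sketch (the paths below are placeholders, not files shipped with this commit), the extended signature lets callers hand the text encoder config to load_clip directly:
from comfy import sd
clip = sd.load_clip(
    ["models/text_encoder/model.safetensors"],                 # placeholder checkpoint path
    embedding_directory="models/embeddings",                   # placeholder
    textmodel_json_config="models/text_encoder/config.json",   # may also be a raw JSON string or an already-parsed dict
)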
def load_gligen(ckpt_path): def load_gligen(ckpt_path):
data = utils.load_torch_file(ckpt_path, safe_load=True) data = utils.load_torch_file(ckpt_path, safe_load=True)
model = gligen.load_gligen(data) model = gligen.load_gligen(data)
@ -416,10 +430,11 @@ def load_gligen(ckpt_path):
model = model.half() model = model.half()
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.") logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True) model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
#TODO: this function is a mess and should be removed eventually # TODO: this function is a mess and should be removed eventually
if config is None: if config is None:
with open(config_path, 'r') as stream: with open(config_path, 'r') as stream:
config = yaml.safe_load(stream) config = yaml.safe_load(stream)
@ -430,8 +445,10 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "parameterization" in model_config_params: if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v": if model_config_params["parameterization"] == "v":
m = model.clone() m = model.clone()
class ModelSamplingAdvanced(model_sampling.ModelSamplingDiscrete, model_sampling.V_PREDICTION): class ModelSamplingAdvanced(model_sampling.ModelSamplingDiscrete, model_sampling.V_PREDICTION):
pass pass
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config)) m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
model = m model = m
@ -441,6 +458,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = utils.load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
@ -467,7 +485,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) inital_load_device = model_management.unet_initial_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
@ -509,18 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (_model_patcher, clip, vae, clipvision) return (_model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd): #load unet in diffusers format def load_unet_state_dict(sd): # load unet in diffusers format
parameters = utils.calculate_parameters(sd) parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "") model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None: if model_config is None:
return None return None
new_sd = sd new_sd = sd
else: #diffusers else: # diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd) model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None: if model_config is None:
return None return None
@ -546,6 +564,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
logging.info("left over keys in unet: {}".format(left_over)) logging.info("left over keys in unet: {}".format(left_over))
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path): def load_unet(unet_path):
sd = utils.load_torch_file(unet_path) sd = utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd) model = load_unet_state_dict(sd)
@ -554,6 +573,7 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model return model
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None clip_sd = None
load_models = [model] load_models = [model]

View File

@ -1,15 +1,45 @@
import os from __future__ import annotations
from transformers import CLIPTokenizer import importlib.resources as resources
from . import ops
import torch
import traceback
import zipfile
from . import model_management
from pkg_resources import resource_filename
from . import clip_model
import json import json
import logging import logging
import os
import traceback
import zipfile
from typing import List
import torch
from pkg_resources import resource_filename
from transformers import CLIPTokenizer
from . import clip_model
from . import model_management
from . import ops
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
config: dict | None = None
if text_model_config_or_path is None:
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
if isinstance(text_model_config_or_path, str):
if text_model_config_or_path.startswith("{"):
config = json.loads(text_model_config_or_path)
else:
if not os.path.exists(text_model_config_or_path):
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(text_model_config_or_path) as f:
config = json.load(f)
elif isinstance(text_model_config_or_path, dict):
config = text_model_config_or_path
assert config is not None
return config
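A hedged illustration of the resolution order implemented above; the on-disk path is a placeholder, and the packaged file name is the default referenced elsewhere in this commit:
cfg_default = get_clip_config_dict(None, "sd1_clip_config.json")                        # packaged default config
cfg_inline = get_clip_config_dict('{"hidden_size": 1024}', "sd1_clip_config.json")      # raw JSON string
cfg_on_disk = get_clip_config_dict("/tmp/clip_config.json", "sd1_clip_config.json")     # explicit file, used if it exists
cfg_as_dict = get_clip_config_dict({"hidden_size": 1024}, "sd1_clip_config.json")       # already-parsed dict, returned as-is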
def gen_empty_tokens(special_tokens, length): def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None) start_token = special_tokens.get("start", None)
@ -23,6 +53,7 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output)) output += [pad_token] * (length - len(output))
return output return output
class ClipTokenWeightEncoder: class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
to_encode = list() to_encode = list()
@ -46,7 +77,7 @@ class ClipTokenWeightEncoder:
output = [] output = []
for k in range(0, sections): for k in range(0, sections):
z = out[k:k+1] z = out[k:k + 1]
if has_weights: if has_weights:
z_empty = out[-1] z_empty = out[-1]
for i in range(len(z)): for i in range(len(z)):
@ -60,6 +91,7 @@ class ClipTokenWeightEncoder:
return out[-1:].to(model_management.intermediate_device()), first_pooled return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [ LAYERS = [
@ -67,20 +99,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"pooled", "pooled",
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel, freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32 special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
super().__init__() super().__init__()
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS assert layer in self.LAYERS
if textmodel_json_config is None: config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
with open(textmodel_json_config) as f:
config = json.load(f)
self.transformer = model_class(config, dtype, device, ops.manual_cast) self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers self.num_layers = self.transformer.num_layers
@ -105,7 +133,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
#self.train = disabled_train # self.train = disabled_train
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
@ -132,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp = [] tokens_temp = []
for y in x: for y in x:
if isinstance(y, int): if isinstance(y, int):
if y == token_dict_size: #EOS token if y == token_dict_size: # EOS token
y = -1 y = -1
tokens_temp += [y] tokens_temp += [y]
else: else:
@ -153,12 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for x in embedding_weights: for x in embedding_weights:
new_embedding.weight[n] = x new_embedding.weight[n] = x
n += 1 n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding new_embedding.weight[n] = current_embeds.weight[-1] # EOS embedding
self.transformer.set_input_embeddings(new_embedding) self.transformer.set_input_embeddings(new_embedding)
processed_tokens = [] processed_tokens = []
for x in out_tokens: for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one
return processed_tokens return processed_tokens
@ -201,6 +229,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd): def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False) return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string): def parse_parentheses(string):
result = [] result = []
current_item = "" current_item = ""
@ -229,6 +258,7 @@ def parse_parentheses(string):
result.append(current_item) result.append(current_item)
return result return result
def token_weights(string, current_weight): def token_weights(string, current_weight):
a = parse_parentheses(string) a = parse_parentheses(string)
out = [] out = []
@ -240,7 +270,7 @@ def token_weights(string, current_weight):
weight *= 1.1 weight *= 1.1
if xx > 0: if xx > 0:
try: try:
weight = float(x[xx+1:]) weight = float(x[xx + 1:])
x = x[:xx] x = x[:xx]
except: except:
pass pass
@ -249,16 +279,19 @@ def token_weights(string, current_weight):
out += [(x, current_weight)] out += [(x, current_weight)]
return out return out
def escape_important(text): def escape_important(text):
text = text.replace("\\)", "\0\1") text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2") text = text.replace("\\(", "\0\2")
return text return text
def unescape_important(text): def unescape_important(text):
text = text.replace("\0\1", ")") text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(") text = text.replace("\0\2", "(")
return text return text
def safe_load_embed_zip(embed_path): def safe_load_embed_zip(embed_path):
with zipfile.ZipFile(embed_path) as myzip: with zipfile.ZipFile(embed_path) as myzip:
names = list(filter(lambda a: "data/" in a, myzip.namelist())) names = list(filter(lambda a: "data/" in a, myzip.namelist()))
@ -267,17 +300,18 @@ def safe_load_embed_zip(embed_path):
with myzip.open(n) as myfile: with myzip.open(n) as myfile:
data = myfile.read() data = myfile.read()
number = len(data) // 4 number = len(data) // 4
length_embed = 1024 #sd2.x length_embed = 1024 # sd2.x
if number < 768: if number < 768:
continue continue
if number % 768 == 0: if number % 768 == 0:
length_embed = 768 #sd1.x length_embed = 768 # sd1.x
num_embeds = number // length_embed num_embeds = number // length_embed
embed = torch.frombuffer(data, dtype=torch.float) embed = torch.frombuffer(data, dtype=torch.float)
out = embed.reshape((num_embeds, length_embed)).clone() out = embed.reshape((num_embeds, length_embed)).clone()
del embed del embed
return out return out
def expand_directory_list(directories): def expand_directory_list(directories):
dirs = set() dirs = set()
for x in directories: for x in directories:
@ -286,6 +320,7 @@ def expand_directory_list(directories):
dirs.add(root) dirs.add(root)
return list(dirs) return list(dirs)
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None): def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
if isinstance(embedding_directory, str): if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory] embedding_directory = [embedding_directory]
@ -356,6 +391,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values)) embed_out = next(iter(values))
return embed_out return embed_out
class SDTokenizer: class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None): def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
if tokenizer_path is None: if tokenizer_path is None:
@ -378,16 +414,20 @@ class SDTokenizer:
self.end_token = empty[0] self.end_token = empty[0]
self.pad_with_end = pad_with_end self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length self.pad_to_max_length = pad_to_max_length
self.add_tokens([])
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory self.embedding_directory = embedding_directory
self.max_word_length = 8 self.max_word_length = 8
self.embedding_identifier = "embedding:" self.embedding_identifier = "embedding:"
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.embedding_key = embedding_key self.embedding_key = embedding_key
def _try_get_embedding(self, embedding_name:str): def add_tokens(self, tokens: List[str]):
if len(tokens) > 0:
self.tokenizer.add_tokens(tokens)
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
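A small usage sketch for the new add_tokens helper (the token strings are illustrative); it registers the tokens with the wrapped CLIPTokenizer and refreshes inv_vocab so untokenize stays consistent:
tok = SDTokenizer(embedding_directory=None)
tok.add_tokens(["mytag0</w>", "mytag1</w>"])            # hypothetical custom tokens
token_id = tok.tokenizer.get_vocab()["mytag0</w>"]
assert tok.inv_vocab[token_id] == "mytag0</w>"          # inverse mapping was rebuilt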
def _try_get_embedding(self, embedding_name: str):
''' '''
Takes a potential embedding name and tries to retrieve it. Takes a potential embedding name and tries to retrieve it.
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
@ -400,8 +440,7 @@ class SDTokenizer:
return (embed, embedding_name[len(stripped):]) return (embed, embedding_name[len(stripped):])
return (embed, "") return (embed, "")
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text: str, return_word_ids=False):
''' '''
Takes a prompt and converts it to a list of (token, weight, word id) elements. Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can be either integer token ids or precomputed CLIP tensors. Tokens can be either integer token ids or precomputed CLIP tensors.
@ -417,13 +456,13 @@ class SDTokenizer:
parsed_weights = token_weights(text, 1.0) parsed_weights = token_weights(text, 1.0)
vocab = self.tokenizer.get_vocab() vocab = self.tokenizer.get_vocab()
#tokenize words # tokenize words
tokens = [] tokens = []
for weighted_segment, weight in parsed_weights: for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""] to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize: for word in to_tokenize:
#if we find an embedding, deal with the embedding # if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n') embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name) embed, leftover = self._try_get_embedding(embedding_name)
@ -434,52 +473,54 @@ class SDTokenizer:
tokens.append([(embed, weight)]) tokens.append([(embed, weight)])
else: else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "": if leftover != "":
word = leftover word = leftover
else: else:
continue continue
#parse word # parse word
exact_word = f"{word}</w>" exact_word = f"{word}</w>"
if exact_word in vocab: if word == self.tokenizer.eos_token:
tokenizer_result = [self.tokenizer.eos_token_id]
elif exact_word in vocab:
tokenizer_result = [vocab[exact_word]] tokenizer_result = [vocab[exact_word]]
else: else:
tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:-1] tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:-1]
tokens.append([(t, weight) for t in tokenizer_result]) tokens.append([(t, weight) for t in tokenizer_result])
#reshape token array to CLIP input size # reshape token array to CLIP input size
batched_tokens = [] batched_tokens = []
batch = [] batch = []
if self.start_token is not None: if self.start_token is not None:
batch.append((self.start_token, 1.0, 0)) batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch) batched_tokens.append(batch)
for i, t_group in enumerate(tokens): for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch # determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0: while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1: if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1 remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token # break word in two and add end token
if is_large: if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0)) batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:] t_group = t_group[remaining_length:]
#add end token and pad # add end token and pad
else: else:
batch.append((self.end_token, 1.0, 0)) batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length: if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch # start new batch
batch = [] batch = []
if self.start_token is not None: if self.start_token is not None:
batch.append((self.start_token, 1.0, 0)) batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch) batched_tokens.append(batch)
else: else:
batch.extend([(t,w,i+1) for t,w in t_group]) batch.extend([(t, w, i + 1) for t, w in t_group])
t_group = [] t_group = []
#fill last batch # fill last batch
batch.append((self.end_token, 1.0, 0)) batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length: if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
@ -487,11 +528,10 @@ class SDTokenizer:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch))) batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids: if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens return batched_tokens
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
@ -502,21 +542,25 @@ class SD1Tokenizer:
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory)) setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {} out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) out[self.clip_name] = self.sd_tokenizer.tokenize_with_weights(text, return_word_ids)
return out return out
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair) return self.sd_tokenizer.untokenize(token_weight_pair)
@property
def sd_tokenizer(self) -> SDTokenizer:
return getattr(self, self.clip)
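This property is what later code in this commit (for example the TextDiffuser nodes) uses to reach the wrapped SDTokenizer without knowing the clip_{name} attribute, roughly:
# tokenizer = clip.tokenizer.sd_tokenizer        # instead of getattr(clip.tokenizer, "clip_l")
# tokenizer.add_tokens(["mytag0</w>"])           # illustrative token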
class SD1ClipModel(torch.nn.Module): class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
super().__init__() super().__init__()
self.clip_name = clip_name self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
def set_clip_options(self, options): def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options) getattr(self, self.clip).set_clip_options(options)

View File

@ -1,27 +1,28 @@
from pkg_resources import resource_filename
from . import sd1_clip from . import sd1_clip
import os
from .sd1_clip import get_clip_config_dict
class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer = "hidden"
layer_idx=-2 layer_idx = -2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer): class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None): def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel): class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs): def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, textmodel_json_config=textmodel_json_config, **kwargs)

View File

@ -1,20 +1,23 @@
from . import sd1_clip
import torch import torch
import os
from . import sd1_clip
from .sd1_clip import get_clip_config_dict
class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
if layer == "penultimate": if layer == "penultimate":
layer="hidden" layer = "hidden"
layer_idx=-2 layer_idx = -2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SDTokenizer): class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
@ -25,7 +28,7 @@ class SDXLTokenizer:
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False): def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {} out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
@ -34,6 +37,7 @@ class SDXLTokenizer:
def untokenize(self, token_weight_pair): def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair) return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
@ -61,28 +65,32 @@ class SDXLClipModel(torch.nn.Module):
else: else:
return self.clip_l.load_sd(sd) return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, textmodel_json_config=textmodel_json_config)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None): def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel): class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config)

View File

@ -1,11 +1,15 @@
from __future__ import annotations
import contextlib
import logging import logging
import math import math
import os.path import os.path
import random
import struct import struct
import sys import sys
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional from typing import Optional, Any
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -14,6 +18,7 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from . import checkpoint_pickle, interruption from . import checkpoint_pickle, interruption
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.queue_types import BinaryEventTypes from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context from .execution_context import current_execution_context
@ -30,6 +35,7 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled)) setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
@ -498,8 +504,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
return output return output
def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None): def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
server = current_execution_context().server server = server or current_execution_context().server
# todo: this should really be from the context. right now the server is behaving like a context # todo: this should really be from the context. right now the server is behaving like a context
client_id = client_id or server.client_id client_id = client_id or server.client_id
interruption.throw_exception_if_processing_interrupted() interruption.throw_exception_if_processing_interrupted()
@ -570,15 +576,14 @@ def comfy_tqdm():
""" """
_original_init = tqdm.__init__ _original_init = tqdm.__init__
_original_update = tqdm.update _original_update = tqdm.update
server = current_execution_context().server
try: try:
def __init(self, *args, **kwargs): def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs) _original_init(self, *args, **kwargs)
self._progress_bar = ProgressBar(self.total)
def __update(self, n=1): def __update(self, n=1):
assert self._progress_bar is not None
_original_update(self, n) _original_update(self, n)
self._progress_bar.update(n) _progress_bar_update(n, self.total, server=server)
tqdm.__init__ = __init tqdm.__init__ = __init
tqdm.update = __update tqdm.update = __update
@ -596,3 +601,30 @@ def comfy_progress(total: float) -> ProgressBar:
yield ProgressBar(total) yield ProgressBar(total)
else: else:
yield _DisabledProgressBar() yield _DisabledProgressBar()
@contextlib.contextmanager
def seed_for_block(seed):
# Save the current random state
torch_rng_state = torch.get_rng_state()
random_state = random.getstate()
numpy_rng_state = np.random.get_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all()
# Set the new seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
yield
finally:
# Restore the previous random state
torch.set_rng_state(torch_rng_state)
random.setstate(random_state)
np.random.set_state(numpy_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state_all(cuda_rng_state)
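A short usage sketch: code inside the block sees a fixed seed, and the previous RNG state for torch, random, numpy (and CUDA, when present) is restored afterwards:
with seed_for_block(42):
    reproducible = torch.randn(4)    # same values every run for a given seed
unaffected = torch.randn(4)          # drawn from the restored, original RNG stream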

View File

@ -0,0 +1,59 @@
/**
* Uses code adapted from https://github.com/Zuellni/ComfyUI-ExLlama-Nodes
*
* MIT License
*
* Copyright (c) 2023 Zuellni
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
import { app } from "../../scripts/app.js";
import { ComfyWidgets } from "../../scripts/widgets.js";
app.registerExtension({
name: "Comfy.StringNodes",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "PreviewString" || nodeData.name === "SaveString") {
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function ({ string }) {
onExecuted?.apply(this, arguments);
if (this.widgets) {
const index = this.widgets.findIndex((w) => w.name === "output");
if (index !== -1) {
for (let i = index; i < this.widgets.length; i++) {
this.widgets[i].onRemove?.();
}
this.widgets.length = index;
}
const options = ["STRING", { multiline: true }];
const widget = ComfyWidgets["STRING"](this, "output", options, app).widget;
widget.inputEl.readOnly = true;
widget.inputEl.style.opacity = 0.7;
widget.value = string;
}
};
}
},
});

View File

@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Any, List, Dict
import torch
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from comfy.language.language_types import ProcArgsRes
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_tqdm, seed_for_block
_transformer_args_deterministic_decoding = {
"max_length": ("INT", {"default": 4096, "min": 1}),
"temperature": ("FLOAT", {"default": 0.7, "min": 0}),
"repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
}
def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
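# Note: "seed" is not a key of _transformer_args_deterministic_decoding, so it never
# survives the filter above; this pop therefore currently always falls back to 0
# unless that mapping (or the caller) is extended to include a seed.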
seed = generate_kwargs.pop("seed", 0)
return ProcArgsRes(seed, generate_kwargs)
class TransformersLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),)
}
}
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str):
with comfy_tqdm():
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
return model_managed,
class SimpleBatchDecode(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
tokenizer = model.tokenizer
inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
with comfy_tqdm():
with seed_for_block(seed):
generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return outputs,
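A hedged sketch of driving these nodes directly from Python, assuming ComfyUI's model management is initialized; "gpt2" stands in for any repo id returned by huggingface_repos():
loader = TransformersLoader()
(model,) = loader.execute("gpt2")                      # placeholder repo id
(texts,) = SimpleBatchDecode().execute(
    model,
    "Write a haiku about latent space.",
    max_length=128, temperature=0.7, repetition_penalty=1.0,
)
print(texts[0])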
class SimpleInstruct(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
conv = get_conversation_template(model.repo_id)
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = model.tokenizer([prompt], return_token_type_ids=False)
inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
with seed_for_block(seed):
output_ids = model.model.generate(
**inputs,
do_sample=True,
**generate_kwargs
)
if model.model.config.is_encoder_decoder:
output_ids = output_ids[0]
else:
output_ids = output_ids[0][len(inputs["input_ids"][0]):]
outputs = model.tokenizer.decode(
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
)
return outputs,
class PreviewString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
}
}
FUNCTION = "execute"
RETURN_TYPES = ("STRING",)
OUTPUT_NODE = True
def execute(self, value: str):
return {"ui": {"string": [value]}}
NODE_CLASS_MAPPINGS = {}
for cls in (
TransformersLoader,
SimpleBatchDecode,
SimpleInstruct,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -0,0 +1,142 @@
"""
Adapted from https://github.com/microsoft/unilm/blob/master/textdiffuser-2/inference_textdiffuser2_t2i_full.py#L334
The MIT License (MIT)
Copyright (c) Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import string
from typing import Optional
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.sd import CLIP
from comfy.sd1_clip import SDTokenizer
class TextDiffuserTokens(CustomNode):
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' ' # len(alphabet) = 95
TOKENS = []
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"clip": ("CLIP",)
}
}
RETURN_TYPES = ("CLIP",)
FUNCTION = "execute"
def execute(self, clip: CLIP):
if len(TextDiffuserTokens.TOKENS) == 0:
for i in range(520):
TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
TextDiffuserTokens.TOKENS.append(f't{i}</w>')
TextDiffuserTokens.TOKENS.append(f'r{i}</w>')
TextDiffuserTokens.TOKENS.append(f'b{i}</w>')
for c in TextDiffuserTokens.ALPHABET:
TextDiffuserTokens.TOKENS.append(f'[{c}]</w>')
tokenizer: SDTokenizer = clip.tokenizer.sd_tokenizer
existing_vocab = frozenset(tokenizer.tokenizer.get_vocab().keys())
tokens = [t for t in TextDiffuserTokens.TOKENS if t not in existing_vocab]
if len(tokens) != 0:
tokenizer.add_tokens(tokens)
# todo: assert that the clip's vocab size is what we expect
return clip,
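For scale, a rough accounting of what the first call adds (derived from the loops above):
# 520 positions x 4 prefixes (l/t/r/b) = 2080 coordinate tokens such as "l0</w>",
# plus len(ALPHABET) = 95 character tokens such as "[a]</w>",
# i.e. up to 2175 new vocabulary entries, minus any the tokenizer already has.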
class TextDiffuserPrepare(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"prompt": ("STRING", {"default": "", "multiline": True}),
},
"optional": {
"text": ("STRING", {"default": "", "multiline": True})
}
}
FUNCTION = "execute"
RETURN_TYPES = "STRING",
RETURN_NAMES = "INSTRUCT STRING",
def execute(self, prompt: str, text: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
keywords = text.split("\n") if text else []  # "text" is an optional input and may be None or empty
if len(keywords) > 0:
# text diffusers does indeed format keywords as
# ['some', 'word']
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}'
else:
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
return message,
class TextDiffuserDecodeLayout(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"layout_model": ("MODEL", {}),
"clip": ("CLIP", {}),
"prompt": ("STRING", {}),
"instruct_response": ("STRING", {})
}
}
FUNCTION = "execute"
RETURN_TYPES = "STRING",
RETURN_NAMES = "CLIP STRING",
def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str, *args, **kwargs) -> ValidatedNodeResult:
current_ocr = instruct_response.split('\n')
words = [clip.tokenizer.sd_tokenizer.tokenizer.eos_token, clip.tokenizer.sd_tokenizer.tokenizer.bos_token]
for ocr in current_ocr:
ocr = ocr.strip()
# skip blank lines and lines containing '###' or '.com' (noise the layout model sometimes emits; the '.com' filter is carried over from the adapted inference script)
if len(ocr) == 0 or '###' in ocr or '.com' in ocr:
continue
items = ocr.split()
pred = ' '.join(items[:-1])
box = items[-1]
l, t, r, b = map(int, box.split(','))
words.extend([f'l{l}', f't{t}', f'r{r}', f'b{b}'])
char_list = [f'[{i}]' for i in pred]
words.extend(char_list)
words.append(clip.tokenizer.sd_tokenizer.tokenizer.eos_token)
return prompt + ' ' + ' '.join(words),
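A made-up example of the format this parser expects from the layout model and roughly what it produces:
# instruct_response, one "<keyword> <left>,<top>,<right>,<bottom>" per line:
#   hello 10,20,90,40
#   world 10,50,90,70
# execute() then returns approximately:
#   "<prompt> <|endoftext|> <|startoftext|> l10 t20 r90 b40 [h] [e] [l] [l] [o]
#    l10 t50 r90 b70 [w] [o] [r] [l] [d] <|endoftext|>"
# where <|startoftext|> / <|endoftext|> are the CLIP tokenizer's bos/eos tokens.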
NODE_CLASS_MAPPINGS = {}
for cls in (
TextDiffuserDecodeLayout,
TextDiffuserPrepare,
TextDiffuserTokens,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -5,7 +5,11 @@ torchsde>=0.2.6
einops>=0.6.0 einops>=0.6.0
open-clip-torch>=2.16.0 open-clip-torch>=2.16.0
transformers>=4.29.1 transformers>=4.29.1
peft
torchinfo
fschat[model_worker]
safetensors>=0.3.0 safetensors>=0.3.0
bitsandbytes
pytorch-lightning>=2.0.0 pytorch-lightning>=2.0.0
aiohttp>=3.8.4 aiohttp>=3.8.4
accelerate>=0.25.0 accelerate>=0.25.0