Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 14:50:49 +08:00)
LLM support in ComfyUI

- Currently uses `transformers`
- Supports model management and correctly loads and unloads models based on what your machine can support
- Includes a Text Diffusers 2 workflow to demonstrate text rendering in SD1.5
This commit is contained in:
parent 0ee2f3bf15
commit 8741cb3ce8
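As orientation before the diff: a minimal sketch of how the new transformers-backed wrapper introduced below could be driven. `TransformersManagedModel` and `load_models_gpu` come from this commit; the repo id, the direct `from_pretrained` call, and the surrounding wiring are illustrative assumptions, not part of the change.

# Hypothetical usage sketch; not part of the commit itself.
from transformers import AutoModelForCausalLM, AutoTokenizer

from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management import load_models_gpu

repo_id = "microsoft/Phi-3-mini-4k-instruct"  # listed in KNOWN_HUGGINGFACE_MODEL_REPOS below
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# Wrapping the model lets ComfyUI's model management size it, keep it parked on the
# offload device, and move it to the load device (freeing other models first) on demand.
managed = TransformersManagedModel(repo_id, model, tokenizer)
load_models_gpu([managed])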
@@ -87,9 +87,10 @@ class CLIPTextModel_(torch.nn.Module):
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
+        vocab_size = config_dict["vocab_size"]

        super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@@ -1,14 +1,18 @@
+from __future__ import annotations
+
import os

from . import sd, utils


-def first_file(path, filenames):
+def first_file(path, filenames) -> str | None:
    for f in filenames:
        p = os.path.join(path, f)
        if os.path.exists(p):
-            return p
+            return str(p)
+    return None


def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
@@ -22,15 +26,17 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
    if text_encoder2_path is not None:
        text_encoder_paths.append(text_encoder2_path)

-    unet = sd.load_unet(unet_path)
+    if unet_path is not None:
+        unet = sd.load_unet(unet_path)

    clip = None
-    if output_clip:
-        clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+    textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
+    if output_clip and not all(te is None for te in text_encoder_paths):
+        clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config1)

    vae = None
-    if output_vae:
+    if output_vae and vae_path is not None:
        _sd = utils.load_torch_file(vae_path)
        vae = sd.VAE(sd=_sd)

-    return (unet, clip, vae)
+    return unet, clip, vae
5 comfy/language/__init__.py Normal file
@@ -0,0 +1,5 @@
from fastchat.model.model_adapter import register_model_adapter

from .fastchat_adapters import Phi3Adapter

register_model_adapter(Phi3Adapter)
62 comfy/language/fastchat_adapters.py Normal file
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import Optional

from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer


class Phi3Adapter(BaseModelAdapter):
    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""

    def match(self, model_path: str):
        return "phi-3-mini-128k-instruct" in model_path.lower()

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        self.model = model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
        )
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
        if input:
            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
        else:
            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
        return prompt

    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
        prompt = self.generate_prompt(messages[-1]["content"])

        for i in range(len(messages) - 2, -1, -1):
            if messages[i]["role"] == "user":
                prompt = self.generate_prompt(messages[i]["content"]) + prompt
            elif messages[i]["role"] == "assistant":
                prompt = messages[i]["content"] + prompt

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        output_ids = self.model.generate(
            input_ids,
            **generation_kwargs
        )

        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        output = output.replace(prompt, "").strip()

        return output

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("phi-3-mini")
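To make the prompt layout above concrete, a short illustrative snippet. The adapter is instantiated directly here; in practice fastchat resolves it through the register_model_adapter call in comfy/language/__init__.py, and generate_response additionally requires load_model to have been called first.

adapter = Phi3Adapter()

# Single turn: generate_prompt wraps the instruction in Phi-3 chat tags.
prompt = adapter.generate_prompt("Write a short SD1.5 prompt for a neon city at night.")
assert prompt == "<|user|>\nWrite a short SD1.5 prompt for a neon city at night.<|end|>\n<|assistant|>"

# Multi turn: generate_response folds earlier turns in front of the newest user message,
# generates with the kwargs shown above, and strips the prompt from the decoded output.
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Describe a cozy cabin."},
]
# reply = adapter.generate_response(messages, max_new_tokens=128)  # needs load_model(...) first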
8 comfy/language/language_types.py Normal file
@@ -0,0 +1,8 @@
from __future__ import annotations

from typing import NamedTuple, Dict, Any


class ProcArgsRes(NamedTuple):
    seed: int
    generate_kwargs: Dict[str, Any]
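A one-line illustration of how callers might unpack this tuple; the field values here are made up.

res = ProcArgsRes(seed=42, generate_kwargs={"max_new_tokens": 256, "do_sample": False})
seed, generate_kwargs = res  # NamedTuple unpacks positionally; fields are also readable as res.seed / res.generate_kwargs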
70 comfy/language/transformers_model_management.py Normal file
@@ -0,0 +1,70 @@
from __future__ import annotations

import warnings
from typing import Optional, Any

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable


class TransformersManagedModel(ModelManageable):
    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
        self.repo_id = repo_id
        self.model = model
        self.tokenizer = tokenizer
        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
        self.load_device = get_torch_device()
        self.offload_device = unet_offload_device()

        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    load_device: torch.device
    offload_device: torch.device
    model: PreTrainedModel

    @property
    def current_device(self) -> torch.device:
        return self.model.device

    def is_clone(self, other: Any) -> bool:
        return hasattr(other, "model") and self.model is other.model

    def clone_has_same_weights(self, clone: Any) -> bool:
        if not isinstance(clone, TransformersManagedModel):
            return False

        clone: TransformersManagedModel

        if not self.is_clone(clone):
            return False

        return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())

    def model_size(self) -> int:
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(arg)

    def model_dtype(self) -> torch.dtype:
        return self.model.dtype

    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=device_to)

    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=device_to)

    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=offload_device)
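A sketch of the management hooks this wrapper exposes. The tiny stand-in checkpoint "sshleifer/tiny-gpt2" and the manual patch/unpatch calls are illustrative assumptions; normally load_models_gpu drives these methods for you.

from transformers import AutoModelForCausalLM, AutoTokenizer

from comfy.language.transformers_model_management import TransformersManagedModel

# Small stand-in model, just for illustration.
tiny = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
managed = TransformersManagedModel("sshleifer/tiny-gpt2", tiny, AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2"))

print(managed.model_size())      # bytes, from the state-dict sums in __init__
print(managed.current_device)    # the offload device after construction (CPU unless high-VRAM mode)
managed.patch_model(managed.load_device, patch_weights=True)   # moves weights to the load device; no LoRA-style patches are applied
managed.unpatch_model(managed.offload_device)                  # moves them back when ComfyUI needs the memory elsewhere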
@@ -7,7 +7,7 @@ from os.path import join
from typing import List, Any, Optional, Union

import tqdm
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, scan_cache_dir
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
@@ -167,6 +167,7 @@ KNOWN_CHECKPOINTS = [
    CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
    CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
+    CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
]

KNOWN_UNCLIP_CHECKPOINTS = [
@@ -297,6 +298,12 @@ KNOWN_VAES = [
    HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
]

+KNOWN_HUGGINGFACE_MODEL_REPOS = {
+    "JingyeChen22/textdiffuser2_layout_planner",
+    'JingyeChen22/textdiffuser2-full-ft',
+    "microsoft/Phi-3-mini-4k-instruct",
+}
+

def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
    if args.disable_known_models:
@@ -304,3 +311,10 @@ def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile
    symbol += models
    folder_paths.invalidate_cache(folder_name)
    return symbol
+
+
+def huggingface_repos() -> List[str]:
+    cache_info = scan_cache_dir()
+    existing_repo_ids = frozenset(cache_item.repo_id for cache_item in cache_info.repos if cache_item.repo_type == "model")
+    known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
+    return list(existing_repo_ids | known_repo_ids)
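For context, the DiffusersLoader node further down in this diff merges these repo ids into its model_path choices; the helper simply returns the union of the local Hugging Face cache and the known set. The printed output below is illustrative, since it depends on what is cached on the machine.

from comfy.model_downloader import huggingface_repos

print(sorted(huggingface_repos()))
# e.g. ['JingyeChen22/textdiffuser2-full-ft', 'JingyeChen22/textdiffuser2_layout_planner', 'microsoft/Phi-3-mini-4k-instruct', ...]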
@ -1,34 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from enum import Enum
|
||||
from threading import RLock
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from .cli_args import args
|
||||
from . import interruption
|
||||
from threading import RLock
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from . import interruption
|
||||
from .cli_args import args
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
model_management_lock = RLock()
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
DISABLED = 0 # No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 # Very low vram: enable all the options to save vram
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
|
||||
SHARED = 5 # No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
|
||||
|
||||
|
||||
class CPUState(Enum):
|
||||
GPU = 0
|
||||
CPU = 1
|
||||
MPS = 2
|
||||
|
||||
|
||||
# Determine VRAM State
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
@ -46,6 +48,7 @@ if args.deterministic:
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
if device_index < 0:
|
||||
@ -54,10 +57,11 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
except:
|
||||
@ -73,6 +77,7 @@ except:
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
|
||||
def is_intel_xpu():
|
||||
global cpu_state
|
||||
global xpu_available
|
||||
@ -81,6 +86,7 @@ def is_intel_xpu():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@ -97,6 +103,7 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -107,7 +114,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_total_torch = mem_total
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
@ -126,6 +133,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
@ -147,6 +155,7 @@ else:
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
try:
|
||||
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
|
||||
@ -164,6 +173,7 @@ else:
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
|
||||
def is_nvidia():
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.GPU:
|
||||
@ -171,6 +181,7 @@ def is_nvidia():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = False
|
||||
if args.use_pytorch_cross_attention:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -205,7 +216,6 @@ elif args.bf16_vae:
|
||||
elif args.fp32_vae:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
@ -233,7 +243,6 @@ if lowvram_available:
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
|
||||
|
||||
if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
@ -247,6 +256,7 @@ DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||
if DISABLE_SMART_MEMORY:
|
||||
logging.info("Disabling smart memory management")
|
||||
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
if device.type == "cuda":
|
||||
@ -262,6 +272,7 @@ def get_torch_device_name(device):
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
|
||||
try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
@ -271,6 +282,7 @@ logging.info("VAE dtype: {}".format(VAE_DTYPE))
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@ -279,6 +291,7 @@ def module_size(module):
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
return module_mem
|
||||
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelManageable):
|
||||
self.model = model
|
||||
@ -328,9 +341,11 @@ class LoadedModel:
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024)
|
||||
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
||||
with model_management_lock:
|
||||
to_unload = []
|
||||
@ -361,12 +376,13 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
|
||||
|
||||
return unload_weight
|
||||
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
with model_management_lock:
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
@ -391,6 +407,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
|
||||
def load_models_gpu(models, memory_required=0):
|
||||
global vram_state
|
||||
|
||||
@ -424,7 +441,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
|
||||
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for device in total_memory_required:
|
||||
@ -432,7 +449,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
|
||||
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
|
||||
if weights_unloaded is not None:
|
||||
loaded_model.weights_loaded = not weights_unloaded
|
||||
|
||||
@ -447,8 +464,8 @@ def load_models_gpu(models, memory_required=0):
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
|
||||
if model_size > (current_free_mem - inference_memory): # only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
else:
|
||||
lowvram_model_memory = 0
|
||||
@ -465,6 +482,7 @@ def load_model_gpu(model):
|
||||
with model_management_lock:
|
||||
return load_models_gpu([model])
|
||||
|
||||
|
||||
def cleanup_models(keep_clone_weights_loaded=False):
|
||||
with model_management_lock:
|
||||
to_delete = []
|
||||
@ -472,8 +490,8 @@ def cleanup_models(keep_clone_weights_loaded=False):
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
#TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
||||
# TODO: find a less fragile way to do this.
|
||||
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: # references from .real_model + the .model
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
@ -481,6 +499,7 @@ def cleanup_models(keep_clone_weights_loaded=False):
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
|
||||
def dtype_size(dtype):
|
||||
dtype_size = 4
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
@ -490,17 +509,19 @@ def dtype_size(dtype):
|
||||
else:
|
||||
try:
|
||||
dtype_size = dtype.itemsize
|
||||
except: #Old pytorch doesn't have .itemsize
|
||||
except: # Old pytorch doesn't have .itemsize
|
||||
pass
|
||||
return dtype_size
|
||||
|
||||
|
||||
def unet_offload_device():
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
def unet_initial_load_device(parameters, dtype):
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return torch_dev
|
||||
@ -518,7 +539,8 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
@ -535,8 +557,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
@ -556,12 +579,14 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
@ -573,6 +598,7 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
@ -595,27 +621,32 @@ def intermediate_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
return get_torch_device()
|
||||
|
||||
|
||||
def vae_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def vae_dtype():
|
||||
global VAE_DTYPE
|
||||
return VAE_DTYPE
|
||||
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
def supports_dtype(device, dtype): #TODO
|
||||
|
||||
def supports_dtype(device, dtype): # TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if is_device_cpu(device):
|
||||
@ -626,12 +657,14 @@ def supports_dtype(device, dtype): #TODO
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
return False # pytorch bug? mps doesn't support non blocking
|
||||
return False
|
||||
# return True #TODO: figure out why this causes issues
|
||||
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
with model_management_lock:
|
||||
device_supports_cast = False
|
||||
@ -655,6 +688,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
else:
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
@ -674,18 +708,21 @@ def xformers_enabled_vae():
|
||||
|
||||
return XFORMERS_ENABLED_VAE
|
||||
|
||||
|
||||
def pytorch_attention_enabled():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
return ENABLE_PYTORCH_ATTENTION
|
||||
|
||||
|
||||
def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if is_nvidia(): #pytorch flash attention only works on Nvidia
|
||||
# TODO: more reliable way of checking for flash attention?
|
||||
if is_nvidia(): # pytorch flash attention only works on Nvidia
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -696,7 +733,7 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_free_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_free_torch = mem_free_total
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
@ -718,29 +755,36 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
else:
|
||||
return mem_free_total
|
||||
|
||||
|
||||
def cpu_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
||||
|
||||
def mps_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
|
||||
def is_device_type(device, type):
|
||||
if hasattr(device, 'type'):
|
||||
if (device.type == type):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_device_cpu(device):
|
||||
return is_device_type(device, 'cpu')
|
||||
|
||||
|
||||
def is_device_mps(device):
|
||||
return is_device_type(device, 'mps')
|
||||
|
||||
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
global directml_enabled
|
||||
|
||||
@ -781,9 +825,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
#when the model doesn't actually fit on the card
|
||||
#TODO: actually test if GP106 and others have the same type of behavior
|
||||
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
# when the model doesn't actually fit on the card
|
||||
# TODO: actually test if GP106 and others have the same type of behavior
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
@ -797,7 +841,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if props.major < 7:
|
||||
return False
|
||||
|
||||
#FP16 is just broken on these cards
|
||||
# FP16 is just broken on these cards
|
||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
|
||||
for x in nvidia_16_series:
|
||||
if x in props.name:
|
||||
@ -805,12 +849,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
if device is not None:
|
||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||
if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
|
||||
return False
|
||||
|
||||
if device is not None: #TODO not sure about mps bf16 support
|
||||
if device is not None: # TODO not sure about mps bf16 support
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
|
||||
@ -842,6 +887,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
with model_management_lock:
|
||||
global cpu_state
|
||||
@ -850,16 +896,17 @@ def soft_empty_cache(force=False):
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
|
||||
if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def unload_all_models():
|
||||
with model_management_lock:
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||
def resolve_lowvram_weight(weight, model, key): # TODO: remove
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import Protocol, Optional
+from typing import Protocol, Optional, Any

import torch

@@ -18,13 +18,12 @@ class ModelManageable(Protocol):
    load_device: torch.device
    offload_device: torch.device
    model: torch.nn.Module
-    current_device: torch.device

    @property
-    def dtype(self) -> torch.dtype:
+    def current_device(self) -> torch.device:
        ...

-    def is_clone(self, other: torch.nn.Module) -> bool:
+    def is_clone(self, other: Any) -> bool:
        ...

    def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:
@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from . import utils
|
||||
import torch
|
||||
|
||||
from . import model_management
|
||||
from . import utils
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
|
||||
@ -20,6 +21,7 @@ def apply_weight_decompose(dora_scale, weight):
|
||||
|
||||
return weight * (dora_scale / weight_norm)
|
||||
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
||||
@ -41,6 +43,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
|
||||
model_options["transformer_options"] = to
|
||||
return model_options
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@ -49,14 +52,15 @@ class ModelPatcher(ModelManageable):
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_options = {"transformer_options": {}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self._current_device: torch.device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
self._current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
self._current_device = current_device
|
||||
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
@ -71,7 +75,7 @@ class ModelPatcher(ModelManageable):
|
||||
return self.size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@ -107,7 +111,7 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) # Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
@ -270,18 +274,20 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
self._current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
@ -325,7 +331,7 @@ class ModelPatcher(ModelManageable):
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
@ -340,14 +346,14 @@ class ModelPatcher(ModelManageable):
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
elif patch_type == "lora": #lora/locon
|
||||
elif patch_type == "lora": # lora/locon
|
||||
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
@ -407,7 +413,7 @@ class ModelPatcher(ModelManageable):
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
if v[5] is not None: # cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
@ -478,10 +484,14 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
self._current_device = value = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
@property
|
||||
def current_device(self) -> torch.device:
|
||||
return self._current_device
|
||||
|
||||
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
from PIL import Image, ImageOps, ImageSequence, ImageFile
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from natsort import natsorted
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
@ -25,11 +26,13 @@ from ..cli_args import args
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..execution_context import current_execution_context
|
||||
from ..images import open_image
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
from ..open_exr import load_exr
|
||||
from .. import node_helpers
|
||||
from ..utils import comfy_tqdm
|
||||
|
||||
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
@ -513,11 +516,14 @@ class DiffusersLoader:
|
||||
if "model_index.json" in files:
|
||||
paths.append(os.path.relpath(root, start=search_path))
|
||||
|
||||
paths += huggingface_repos()
|
||||
paths = list(frozenset(paths))
|
||||
return {"required": {"model_path": (paths,), }}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "advanced/loaders/deprecated"
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
@ -526,6 +532,9 @@ class DiffusersLoader:
|
||||
if os.path.exists(path):
|
||||
model_path = path
|
||||
break
|
||||
if not os.path.exists(model_path):
|
||||
with comfy_tqdm():
|
||||
model_path = snapshot_download(model_path)
|
||||
|
||||
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
@ -1029,7 +1038,7 @@ class LatentFromBatch:
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1044,7 +1053,7 @@ class RepeatLatentBatch:
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
@ -1374,7 +1383,7 @@ class SaveImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
@ -1450,9 +1459,9 @@ class LoadImage:
|
||||
|
||||
def load_image(self, image: str):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
|
||||
|
||||
136 comfy/sd.py
@ -1,31 +1,32 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from . import model_management
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from . import utils
|
||||
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
from . import diffusers_convert
|
||||
from . import gligen
|
||||
from . import lora
|
||||
from . import model_detection
|
||||
|
||||
from . import model_management
|
||||
from . import model_patcher
|
||||
from . import model_sampling
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from . import sdxl_clip
|
||||
|
||||
from . import model_patcher
|
||||
from . import model_sampling
|
||||
from . import lora
|
||||
from . import utils
|
||||
from .ldm.cascade.stage_a import StageA
|
||||
from .ldm.cascade.stage_c_coder import StageC_coder
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
from .t2i_adapter import adapter
|
||||
from .taesd import taesd
|
||||
|
||||
|
||||
def load_model_weights(model, sd):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
m = set(m)
|
||||
@ -40,6 +41,7 @@ def load_model_weights(model, sd):
|
||||
logging.warning("missing {}".format(m))
|
||||
return model
|
||||
|
||||
|
||||
def load_clip_weights(model, sd):
|
||||
k = list(sd.keys())
|
||||
for x in k:
|
||||
@ -87,7 +89,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False):
|
||||
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@ -98,10 +100,12 @@ class CLIP:
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
||||
if "textmodel_json_config" not in params and textmodel_json_config is not None:
|
||||
params['textmodel_json_config'] = textmodel_json_config
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
|
||||
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.layer_idx = None
|
||||
|
||||
@ -157,12 +161,13 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) # These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.upscale_ratio = 8
|
||||
@ -181,16 +186,16 @@ class VAE:
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = taesd.TAESD()
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
elif "vquantizer.codebook.weight" in sd: # VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
#TODO
|
||||
#self.memory_used_encode
|
||||
#self.memory_used_decode
|
||||
# TODO
|
||||
# self.memory_used_encode
|
||||
# self.memory_used_decode
|
||||
self.process_input = lambda image: image
|
||||
self.process_output = lambda image: image
|
||||
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
|
||||
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: # effnet: encoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
@ -198,22 +203,22 @@ class VAE:
|
||||
for k in sd:
|
||||
new_sd["encoder.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
|
||||
elif "blocks.11.num_batches_tracked" in sd: # previewer: decoder for stage c latent of stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.latent_channels = 16
|
||||
new_sd = {}
|
||||
for k in sd:
|
||||
new_sd["previewer.{}".format(k)] = sd[k]
|
||||
sd = new_sd
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
|
||||
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: # combined effnet and previewer for stable cascade
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
# default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: # Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
self.upscale_ratio = 4
|
||||
@ -261,7 +266,7 @@ class VAE:
|
||||
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
|
||||
return pixels
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
|
||||
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
@ -269,22 +274,22 @@ class VAE:
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
@ -298,23 +303,23 @@ class VAE:
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x + batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1,-1)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@ -323,8 +328,8 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
@ -332,16 +337,17 @@ class VAE:
|
||||
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
self.model = model
|
||||
@ -360,26 +366,33 @@ def load_style_model(ckpt_path):
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
|
||||
class CLIPType(Enum):
|
||||
STABLE_DIFFUSION = 1
|
||||
STABLE_CASCADE = 2
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CLIPTarget:
|
||||
clip: Optional[Any] = None
|
||||
vae: Optional[Any] = None
|
||||
params: Optional[dict] = dataclasses.field(default_factory=dict)
|
||||
tokenizer: Optional[Any] = None
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(utils.load_torch_file(p, safe_load=True))
|
||||
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
for i in range(len(clip_data)):
|
||||
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
|
||||
clip_data[i] = utils.clip_text_transformers_convert(clip_data[i], "", "")
|
||||
else:
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) # old models saved with the CLIPSave node
|
||||
|
||||
clip_target = EmptyClass()
|
||||
clip_target = CLIPTarget()
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
@ -399,7 +412,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@ -409,6 +422,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
return clip
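Example sketch (a minimal illustration of the new textmodel_json_config argument; the file paths below are placeholders, not files in this diff):
# Hypothetical call to comfy.sd.load_clip with an explicit text-model config,
# e.g. for a text encoder trained with an extended vocabulary.
from comfy.sd import load_clip
clip = load_clip(
    ["/path/to/text_encoder/model.safetensors"],   # placeholder path
    embedding_directory="models/embeddings",       # placeholder path
    textmodel_json_config="/path/to/config.json",  # placeholder path
)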
|
||||
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
model = gligen.load_gligen(data)
|
||||
@ -416,10 +430,11 @@ def load_gligen(ckpt_path):
|
||||
model = model.half()
|
||||
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
|
||||
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
|
||||
#TODO: this function is a mess and should be removed eventually
|
||||
# TODO: this function is a mess and should be removed eventually
|
||||
if config is None:
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
@ -430,8 +445,10 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
if "parameterization" in model_config_params:
|
||||
if model_config_params["parameterization"] == "v":
|
||||
m = model.clone()
|
||||
|
||||
class ModelSamplingAdvanced(model_sampling.ModelSamplingDiscrete, model_sampling.V_PREDICTION):
|
||||
pass
|
||||
|
||||
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
|
||||
model = m
|
||||
|
||||
@ -441,6 +458,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
sd_keys = sd.keys()
|
||||
@ -467,7 +485,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
inital_load_device = model_management.unet_initial_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
@ -509,18 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
return (_model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
def load_unet_state_dict(sd): # load unet in diffusers format
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
if model_config is None:
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
else: # diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||
if model_config is None:
|
||||
return None
|
||||
@ -546,6 +564,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
logging.info("left over keys in unet: {}".format(left_over))
|
||||
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
|
||||
def load_unet(unet_path):
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd)
|
||||
@ -554,6 +573,7 @@ def load_unet(unet_path):
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
|
||||
clip_sd = None
|
||||
load_models = [model]
|
||||
|
||||
@ -1,15 +1,45 @@
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
from transformers import CLIPTokenizer
|
||||
from . import ops
|
||||
import torch
|
||||
import traceback
|
||||
import zipfile
|
||||
from . import model_management
|
||||
from pkg_resources import resource_filename
|
||||
from . import clip_model
|
||||
import importlib.resources as resources
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import zipfile
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pkg_resources import resource_filename
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from . import clip_model
|
||||
from . import model_management
|
||||
from . import ops
|
||||
|
||||
|
||||
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
|
||||
config: dict | None = None
|
||||
|
||||
if text_model_config_or_path is None:
|
||||
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
|
||||
|
||||
if isinstance(text_model_config_or_path, str):
|
||||
if text_model_config_or_path.startswith("{"):
|
||||
config = json.loads(text_model_config_or_path)
|
||||
else:
|
||||
if not os.path.exists(text_model_config_or_path):
|
||||
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
with open(text_model_config_or_path) as f:
|
||||
config = json.load(f)
|
||||
elif isinstance(text_model_config_or_path, dict):
|
||||
config = text_model_config_or_path
|
||||
|
||||
assert config is not None
|
||||
return config
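Example sketch (the three input forms this helper accepts; the inline JSON, dict values, and config names are illustrative only):
# Hypothetical usage of comfy.sd1_clip.get_clip_config_dict.
from comfy.sd1_clip import get_clip_config_dict

cfg_default = get_clip_config_dict(None, "sd1_clip_config.json")                   # packaged default
cfg_inline = get_clip_config_dict('{"hidden_size": 768}', "sd1_clip_config.json")  # JSON string
cfg_dict = get_clip_config_dict({"hidden_size": 1280}, "clip_config_bigg.json")    # dict passthrough
assert cfg_dict["hidden_size"] == 1280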
|
||||
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
start_token = special_tokens.get("start", None)
|
||||
@ -23,6 +53,7 @@ def gen_empty_tokens(special_tokens, length):
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
@ -46,7 +77,7 @@ class ClipTokenWeightEncoder:
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k+1]
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
@ -60,6 +91,7 @@ class ClipTokenWeightEncoder:
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
@ -67,20 +99,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
if special_tokens is None:
|
||||
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
||||
assert layer in self.LAYERS
|
||||
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
if not os.path.exists(textmodel_json_config):
|
||||
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
|
||||
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
|
||||
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@ -105,7 +133,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@ -132,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens_temp = []
|
||||
for y in x:
|
||||
if isinstance(y, int):
|
||||
if y == token_dict_size: #EOS token
|
||||
if y == token_dict_size: # EOS token
|
||||
y = -1
|
||||
tokens_temp += [y]
|
||||
else:
|
||||
@ -153,12 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
for x in embedding_weights:
|
||||
new_embedding.weight[n] = x
|
||||
n += 1
|
||||
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
|
||||
new_embedding.weight[n] = current_embeds.weight[-1] # EOS embedding
|
||||
self.transformer.set_input_embeddings(new_embedding)
|
||||
|
||||
processed_tokens = []
|
||||
for x in out_tokens:
|
||||
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
|
||||
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one
|
||||
|
||||
return processed_tokens
|
||||
|
||||
@ -201,6 +229,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
@ -229,6 +258,7 @@ def parse_parentheses(string):
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
@ -240,7 +270,7 @@ def token_weights(string, current_weight):
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx+1:])
|
||||
weight = float(x[xx + 1:])
|
||||
x = x[:xx]
|
||||
except:
|
||||
pass
|
||||
@ -249,16 +279,19 @@ def token_weights(string, current_weight):
|
||||
out += [(x, current_weight)]
|
||||
return out
|
||||
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
|
||||
def safe_load_embed_zip(embed_path):
|
||||
with zipfile.ZipFile(embed_path) as myzip:
|
||||
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
|
||||
@ -267,17 +300,18 @@ def safe_load_embed_zip(embed_path):
|
||||
with myzip.open(n) as myfile:
|
||||
data = myfile.read()
|
||||
number = len(data) // 4
|
||||
length_embed = 1024 #sd2.x
|
||||
length_embed = 1024 # sd2.x
|
||||
if number < 768:
|
||||
continue
|
||||
if number % 768 == 0:
|
||||
length_embed = 768 #sd1.x
|
||||
length_embed = 768 # sd1.x
|
||||
num_embeds = number // length_embed
|
||||
embed = torch.frombuffer(data, dtype=torch.float)
|
||||
out = embed.reshape((num_embeds, length_embed)).clone()
|
||||
del embed
|
||||
return out
|
||||
|
||||
|
||||
def expand_directory_list(directories):
|
||||
dirs = set()
|
||||
for x in directories:
|
||||
@ -286,6 +320,7 @@ def expand_directory_list(directories):
|
||||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
|
||||
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
@ -356,6 +391,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
|
||||
if tokenizer_path is None:
|
||||
@ -378,16 +414,20 @@ class SDTokenizer:
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.add_tokens([])
|
||||
self.embedding_directory = embedding_directory
|
||||
self.max_word_length = 8
|
||||
self.embedding_identifier = "embedding:"
|
||||
self.embedding_size = embedding_size
|
||||
self.embedding_key = embedding_key
|
||||
|
||||
def _try_get_embedding(self, embedding_name:str):
|
||||
def add_tokens(self, tokens: List[str]):
|
||||
if len(tokens) > 0:
|
||||
self.tokenizer.add_tokens(tokens)
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
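Example sketch (the new add_tokens hook; the added token string is a placeholder in the TextDiffuser coordinate style):
# Hypothetical usage: extend the CLIP tokenizer vocabulary and keep inv_vocab in sync.
from comfy.sd1_clip import SDTokenizer

tok = SDTokenizer()
before = len(tok.tokenizer.get_vocab())
tok.add_tokens(["l0</w>"])  # placeholder token
assert len(tok.tokenizer.get_vocab()) == before + 1
assert "l0</w>" in tok.inv_vocab.values()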
|
||||
|
||||
def _try_get_embedding(self, embedding_name: str):
|
||||
'''
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
@ -400,8 +440,7 @@ class SDTokenizer:
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
'''
|
||||
Takes a prompt and converts it to a list of (token, weight, word id) elements.
|
||||
Tokens can both be integer tokens and pre computed CLIP tensors.
|
||||
@ -417,13 +456,13 @@ class SDTokenizer:
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
|
||||
#tokenize words
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
#if we find an embedding, deal with the embedding
|
||||
# if we find an embedding, deal with the embedding
|
||||
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
@ -434,52 +473,54 @@ class SDTokenizer:
|
||||
tokens.append([(embed, weight)])
|
||||
else:
|
||||
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
|
||||
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
|
||||
# if we accidentally have leftover text, continue parsing using leftover, else move on to next word
|
||||
if leftover != "":
|
||||
word = leftover
|
||||
else:
|
||||
continue
|
||||
#parse word
|
||||
# parse word
|
||||
exact_word = f"{word}</w>"
|
||||
if exact_word in vocab:
|
||||
if word == self.tokenizer.eos_token:
|
||||
tokenizer_result = [self.tokenizer.eos_token_id]
|
||||
elif exact_word in vocab:
|
||||
tokenizer_result = [vocab[exact_word]]
|
||||
else:
|
||||
tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:-1]
|
||||
tokens.append([(t, weight) for t in tokenizer_result])
|
||||
|
||||
#reshape token array to CLIP input size
|
||||
# reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
#determine if we're going to try and keep the tokens in a single batch
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
#break word in two and add end token
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
#add end token and pad
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||
t_group = []
|
||||
|
||||
#fill last batch
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
@ -487,11 +528,10 @@ class SDTokenizer:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
@ -502,21 +542,25 @@ class SD1Tokenizer:
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
out[self.clip_name] = self.sd_tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
return self.sd_tokenizer.untokenize(token_weight_pair)
|
||||
|
||||
@property
|
||||
def sd_tokenizer(self) -> SDTokenizer:
|
||||
return getattr(self, self.clip)
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
|
||||
super().__init__()
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
|
||||
|
||||
def set_clip_options(self, options):
|
||||
getattr(self, self.clip).set_clip_options(options)
|
||||
|
||||
@ -1,27 +1,28 @@
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from . import sd1_clip
|
||||
import os
|
||||
|
||||
from .sd1_clip import get_clip_config_dict
|
||||
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
if not os.path.exists(textmodel_json_config):
|
||||
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
|
||||
|
||||
|
||||
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
|
||||
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
@ -1,20 +1,23 @@
|
||||
from . import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
from . import sd1_clip
|
||||
from .sd1_clip import get_clip_config_dict
|
||||
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
@ -25,7 +28,7 @@ class SDXLTokenizer:
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
@ -34,6 +37,7 @@ class SDXLTokenizer:
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.clip_g.untokenize(token_weight_pair)
|
||||
|
||||
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
@ -61,28 +65,32 @@ class SDXLClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, textmodel_json_config=textmodel_json_config)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
|
||||
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None):
|
||||
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
|
||||
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config)
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import os.path
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
@ -14,6 +18,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import checkpoint_pickle, interruption
|
||||
from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.queue_types import BinaryEventTypes
|
||||
from .execution_context import current_execution_context
|
||||
|
||||
@ -30,6 +35,7 @@ def _get_progress_bar_enabled():
|
||||
|
||||
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
@ -498,8 +504,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
|
||||
return output
|
||||
|
||||
|
||||
def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
|
||||
server = current_execution_context().server
|
||||
def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
|
||||
server = server or current_execution_context().server
|
||||
# todo: this should really be from the context. right now the server is behaving like a context
|
||||
client_id = client_id or server.client_id
|
||||
interruption.throw_exception_if_processing_interrupted()
|
||||
@ -570,15 +576,14 @@ def comfy_tqdm():
|
||||
"""
|
||||
_original_init = tqdm.__init__
|
||||
_original_update = tqdm.update
|
||||
server = current_execution_context().server
|
||||
try:
|
||||
def __init(self, *args, **kwargs):
|
||||
_original_init(self, *args, **kwargs)
|
||||
self._progress_bar = ProgressBar(self.total)
|
||||
|
||||
def __update(self, n=1):
|
||||
assert self._progress_bar is not None
|
||||
_original_update(self, n)
|
||||
self._progress_bar.update(n)
|
||||
_progress_bar_update(n, self.total, server=server)
|
||||
|
||||
tqdm.__init__ = __init
|
||||
tqdm.update = __update
|
||||
@ -596,3 +601,30 @@ def comfy_progress(total: float) -> ProgressBar:
|
||||
yield ProgressBar(total)
|
||||
else:
|
||||
yield _DisabledProgressBar()
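Example sketch (hedged; assumes comfy_progress is used as a context manager, as its yield-based body suggests, and the repo id is a placeholder):
# Hypothetical usage of comfy.utils.comfy_tqdm / comfy_progress.
from transformers import AutoModelForCausalLM
from comfy.utils import comfy_tqdm, comfy_progress

with comfy_tqdm():
    # tqdm bars created inside this block (e.g. a Hugging Face download) are mirrored to the client
    model = AutoModelForCausalLM.from_pretrained("some-org/some-model")  # placeholder repo id

with comfy_progress(total=10) as progress:
    for _ in range(10):
        progress.update(1)  # explicit progress for hand-written loops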
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def seed_for_block(seed):
|
||||
# Save the current random state
|
||||
torch_rng_state = torch.get_rng_state()
|
||||
random_state = random.getstate()
|
||||
numpy_rng_state = np.random.get_state()
|
||||
if torch.cuda.is_available():
|
||||
cuda_rng_state = torch.cuda.get_rng_state_all()
|
||||
|
||||
# Set the new seed
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore the previous random state
|
||||
torch.set_rng_state(torch_rng_state)
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(numpy_rng_state)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_rng_state_all(cuda_rng_state)
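Example sketch (usage of the new seed_for_block helper):
# Hypothetical usage: identical seeds give identical draws, and the prior RNG
# state of torch / random / numpy (and CUDA, if present) is restored afterwards.
import torch
from comfy.utils import seed_for_block

with seed_for_block(42):
    a = torch.randn(4)
with seed_for_block(42):
    b = torch.randn(4)
assert torch.equal(a, b)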
|
||||
|
||||
59
comfy/web/extensions/core/textExtraOutput.js
Normal file
@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Uses code adapted from https://github.com/Zuellni/ComfyUI-ExLlama-Nodes
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2023 Zuellni
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*/
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { ComfyWidgets } from "../../scripts/widgets.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.StringNodes",
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
if (nodeData.name === "PreviewString" || nodeData.name === "SaveString") {
|
||||
const onExecuted = nodeType.prototype.onExecuted;
|
||||
|
||||
nodeType.prototype.onExecuted = function ({ string }) {
|
||||
onExecuted?.apply(this, arguments);
|
||||
|
||||
if (this.widgets) {
|
||||
const index = this.widgets.findIndex((w) => w.name === "output");
|
||||
|
||||
if (index !== -1) {
|
||||
for (let i = index; i < this.widgets.length; i++) {
|
||||
this.widgets[i].onRemove?.();
|
||||
}
|
||||
|
||||
this.widgets.length = index;
|
||||
}
|
||||
|
||||
const options = ["STRING", { multiline: true }];
|
||||
const widget = ComfyWidgets["STRING"](this, "output", options, app).widget;
|
||||
|
||||
widget.inputEl.readOnly = true;
|
||||
widget.inputEl.style.opacity = 0.7;
|
||||
widget.value = string;
|
||||
}
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
139
comfy_extras/nodes/nodes_language.py
Normal file
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import torch
|
||||
from fastchat.model import get_conversation_template
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from comfy.language.language_types import ProcArgsRes
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
from comfy.model_downloader import huggingface_repos
|
||||
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||
from comfy.utils import comfy_tqdm, seed_for_block
|
||||
|
||||
_transformer_args_deterministic_decoding = {
|
||||
"max_length": ("INT", {"default": 4096, "min": 1}),
|
||||
"temperature": ("FLOAT", {"default": 0.7, "min": 0}),
|
||||
"repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
|
||||
}
|
||||
|
||||
|
||||
def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
|
||||
generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
|
||||
seed = generate_kwargs.pop("seed", 0)
|
||||
return ProcArgsRes(seed, generate_kwargs)
|
||||
|
||||
|
||||
class TransformersLoader(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (huggingface_repos(),)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = "MODEL",
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, ckpt_name: str):
|
||||
with comfy_tqdm():
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
|
||||
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
|
||||
return model_managed,
|
||||
|
||||
|
||||
class SimpleBatchDecode(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"prompt": ("STRING", {"default": "", "multiline": True}),
|
||||
**_transformer_args_deterministic_decoding
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
|
||||
load_model_gpu(model)
|
||||
seed, generate_kwargs = proc_args(kwargs)
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
|
||||
with comfy_tqdm():
|
||||
with seed_for_block(seed):
|
||||
generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
|
||||
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
return outputs,
|
||||
|
||||
|
||||
class SimpleInstruct(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"prompt": ("STRING", {"default": "", "multiline": True}),
|
||||
**_transformer_args_deterministic_decoding
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
|
||||
load_model_gpu(model)
|
||||
seed, generate_kwargs = proc_args(kwargs)
|
||||
conv = get_conversation_template(model.repo_id)
|
||||
conv.append_message(conv.roles[0], prompt)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
inputs = model.tokenizer([prompt], return_token_type_ids=False)
|
||||
inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
|
||||
with seed_for_block(seed):
|
||||
output_ids = model.model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
**generate_kwargs
|
||||
)
|
||||
if model.model.config.is_encoder_decoder:
|
||||
output_ids = output_ids[0]
|
||||
else:
|
||||
output_ids = output_ids[0][len(inputs["input_ids"][0]):]
|
||||
outputs = model.tokenizer.decode(
|
||||
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
|
||||
)
|
||||
return outputs,
|
||||
|
||||
|
||||
class PreviewString(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {}),
|
||||
}
|
||||
}
|
||||
|
||||
FUNCTION = "execute"
|
||||
RETURN_TYPES = ("STRING",)
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def execute(self, value: str):
|
||||
return {"ui": {"string": [value]}}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
for cls in (
|
||||
TransformersLoader,
|
||||
SimpleBatchDecode,
|
||||
SimpleInstruct,
|
||||
PreviewString,
|
||||
):
|
||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
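Example sketch (hedged; how these nodes compose when called directly in Python rather than through the graph, with placeholder repo id, prompt, and sampling values):
# Hypothetical direct use of the new LLM nodes.
from comfy_extras.nodes.nodes_language import TransformersLoader, SimpleInstruct, PreviewString

model, = TransformersLoader().execute("microsoft/Phi-3-mini-4k-instruct")  # placeholder repo id
reply, = SimpleInstruct().execute(model, "Summarize ComfyUI in one sentence.",
                                  max_length=256, temperature=0.7, repetition_penalty=1.0)
PreviewString().execute(reply)  # surfaces the string in the UI via textExtraOutput.js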
|
||||
142
comfy_extras/nodes/nodes_textdiffusers.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Adapted from https://github.com/microsoft/unilm/blob/master/textdiffuser-2/inference_textdiffuser2_t2i_full.py#L334
|
||||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) Microsoft Corporation
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||
from comfy.sd import CLIP
|
||||
from comfy.sd1_clip import SDTokenizer
|
||||
|
||||
|
||||
class TextDiffuserTokens(CustomNode):
|
||||
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' ' # len(alphabet) = 95
|
||||
TOKENS = []
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, clip: CLIP):
|
||||
if len(TextDiffuserTokens.TOKENS) == 0:
|
||||
for i in range(520):
|
||||
TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
|
||||
TextDiffuserTokens.TOKENS.append(f't{i}</w>')
|
||||
TextDiffuserTokens.TOKENS.append(f'r{i}</w>')
|
||||
TextDiffuserTokens.TOKENS.append(f'b{i}</w>')
|
||||
for c in TextDiffuserTokens.ALPHABET:
|
||||
TextDiffuserTokens.TOKENS.append(f'[{c}]</w>')
|
||||
tokenizer: SDTokenizer = clip.tokenizer.sd_tokenizer
|
||||
existing_vocab = frozenset(tokenizer.tokenizer.get_vocab().keys())
|
||||
tokens = [t for t in TextDiffuserTokens.TOKENS if t not in existing_vocab]
|
||||
if len(tokens) != 0:
|
||||
tokenizer.add_tokens(tokens)
|
||||
|
||||
# todo: assert that the clip's vocab size is what we expect
|
||||
return clip,
|
||||
|
||||
|
||||
class TextDiffuserPrepare(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": ("STRING", {"default": "", "multiline": True}),
|
||||
},
|
||||
"optional": {
|
||||
"text": ("STRING", {"default": "", "multiline": True})
|
||||
}
|
||||
}
|
||||
|
||||
FUNCTION = "execute"
|
||||
RETURN_TYPES = "STRING",
|
||||
RETURN_NAMES = "INSTRUCT STRING",
|
||||
|
||||
def execute(self, prompt: str, text: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
|
||||
keywords = [k for k in (text or "").split("\n") if k != ""]
if len(keywords) > 0:
|
||||
# text diffusers does indeed format keywords as
|
||||
# ['some', 'word']
|
||||
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}'
|
||||
else:
|
||||
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
|
||||
|
||||
return message,
|
||||
|
||||
|
||||
class TextDiffuserDecodeLayout(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"layout_model": ("MODEL", {}),
|
||||
"clip": ("CLIP", {}),
|
||||
"prompt": ("STRING", {}),
|
||||
"instruct_response": ("STRING", {})
|
||||
}
|
||||
}
|
||||
|
||||
FUNCTION = "execute"
|
||||
RETURN_TYPES = "STRING",
|
||||
RETURN_NAMES = "CLIP STRING",
|
||||
|
||||
def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str, *args, **kwargs) -> ValidatedNodeResult:
|
||||
current_ocr = instruct_response.split('\n')
|
||||
words = [clip.tokenizer.sd_tokenizer.tokenizer.eos_token, clip.tokenizer.sd_tokenizer.tokenizer.bos_token]
|
||||
for ocr in current_ocr:
|
||||
ocr = ocr.strip()
|
||||
|
||||
# .com ??
|
||||
if len(ocr) == 0 or '###' in ocr or '.com' in ocr:
|
||||
continue
|
||||
|
||||
items = ocr.split()
|
||||
pred = ' '.join(items[:-1])
|
||||
box = items[-1]
|
||||
|
||||
l, t, r, b = map(int, box.split(','))
|
||||
words.extend([f'l{l}', f't{t}', f'r{r}', f'b{b}'])
|
||||
|
||||
char_list = [f'[{i}]' for i in pred]
|
||||
words.extend(char_list)
|
||||
words.append(clip.tokenizer.sd_tokenizer.tokenizer.eos_token)
|
||||
return prompt + ' ' + ' '.join(words),
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
for cls in (
|
||||
TextDiffuserDecodeLayout,
|
||||
TextDiffuserPrepare,
|
||||
TextDiffuserTokens,
|
||||
):
|
||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
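Example sketch (hedged; end-to-end wiring of the Text Diffusers 2 workflow, where clip is assumed to come from the usual SD1.5 checkpoint loader and the layout-planner repo id is a placeholder):
# Hypothetical wiring of the Text Diffusers 2 workflow mentioned in the commit message.
from comfy_extras.nodes.nodes_language import TransformersLoader, SimpleInstruct
from comfy_extras.nodes.nodes_textdiffusers import TextDiffuserTokens, TextDiffuserPrepare, TextDiffuserDecodeLayout

layout_model, = TransformersLoader().execute("JingyeChen22/textdiffuser2_layout_planner")  # placeholder repo id
clip, = TextDiffuserTokens().execute(clip)  # extend the CLIP vocab with coordinate/char tokens
instruct, = TextDiffuserPrepare().execute("a poster that says 'OPEN'", text="OPEN")
layout, = SimpleInstruct().execute(layout_model, instruct, max_length=4096, temperature=0.7, repetition_penalty=1.0)
prompt_with_layout, = TextDiffuserDecodeLayout().execute(layout_model, clip, "a poster that says 'OPEN'", layout)
# prompt_with_layout then goes to the normal CLIP Text Encode node for SD1.5 sampling.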
|
||||
@ -5,7 +5,11 @@ torchsde>=0.2.6
|
||||
einops>=0.6.0
|
||||
open-clip-torch>=2.16.0
|
||||
transformers>=4.29.1
|
||||
peft
|
||||
torchinfo
|
||||
fschat[model_worker]
|
||||
safetensors>=0.3.0
|
||||
bitsandbytes
|
||||
pytorch-lightning>=2.0.0
|
||||
aiohttp>=3.8.4
|
||||
accelerate>=0.25.0
|
||||