LLM support in ComfyUI

- Currently uses `transformers` (see the sketch below)
- Supports model management, correctly loading and unloading models
  based on what your machine can support
- Includes a TextDiffuser-2 workflow to demonstrate text rendering in
  SD1.5
doctorpangloss 2024-05-14 17:30:23 -07:00
parent 0ee2f3bf15
commit 8741cb3ce8
20 changed files with 893 additions and 213 deletions
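
For orientation, a minimal sketch of the `transformers` generation path this commit builds on (repo id, prompt, and generation parameters are illustrative; the actual nodes wrap this behind model management):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "microsoft/Phi-3-mini-4k-instruct"  # one of the repos registered in this commit
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", trust_remote_code=True)

# Phi-3 style prompt, matching the format used by the adapter added below
prompt = "<|user|>\nDescribe ComfyUI in one sentence.<|end|>\n<|assistant|>"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=64, do_sample=False,
                            pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))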

View File

@ -87,9 +87,10 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
vocab_size = config_dict["vocab_size"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

View File

@ -1,14 +1,18 @@
from __future__ import annotations
import os
from . import sd, utils
def first_file(path, filenames):
def first_file(path, filenames) -> str | None:
for f in filenames:
p = os.path.join(path, f)
if os.path.exists(p):
return p
return str(p)
return None
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
@ -22,15 +26,17 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = sd.load_unet(unet_path)
unet = None
if unet_path is not None:
unet = sd.load_unet(unet_path)
clip = None
if output_clip:
clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
if output_clip and not all(te is None for te in text_encoder_paths):
clip = sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config1)
vae = None
if output_vae:
if output_vae and vae_path is not None:
_sd = utils.load_torch_file(vae_path)
vae = sd.VAE(sd=_sd)
return (unet, clip, vae)
return unet, clip, vae
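
A hypothetical call to the updated loader, assuming a local diffusers-format SD1.5 checkout (the paths are examples only):

unet, clip, vae = load_diffusers("/path/to/stable-diffusion-v1-5",
                                 embedding_directory="models/embeddings")
# Any of the three may be None if the corresponding files are missing from the checkout.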

View File

@ -0,0 +1,5 @@
from fastchat.model.model_adapter import register_model_adapter
from .fastchat_adapters import Phi3Adapter
register_model_adapter(Phi3Adapter)

View File

@ -0,0 +1,62 @@
from __future__ import annotations
from typing import Optional
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer
class Phi3Adapter(BaseModelAdapter):
"""The model adapter for Microsoft/Phi-3-mini-128k-instruct"""
def match(self, model_path: str):
return "phi-3-mini-128k-instruct" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
self.model = model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer
def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
if input:
prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
else:
prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
return prompt
def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
prompt = self.generate_prompt(messages[-1]["content"])
for i in range(len(messages) - 2, -1, -1):
if messages[i]["role"] == "user":
prompt = self.generate_prompt(messages[i]["content"]) + prompt
elif messages[i]["role"] == "assistant":
prompt = messages[i]["content"] + prompt
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.eos_token_id,
}
output_ids = self.model.generate(
input_ids,
**generation_kwargs
)
# decode only the newly generated tokens so the prompt is not echoed in the reply
output = self.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
return output
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("phi-3-mini")
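
A hypothetical direct use of the adapter, outside FastChat's normal dispatch (model path and kwargs are examples only):

adapter = Phi3Adapter()
model, tokenizer = adapter.load_model("microsoft/Phi-3-mini-128k-instruct",
                                      {"torch_dtype": "auto"})
reply = adapter.generate_response(
    [{"role": "user", "content": "Summarize what a VAE does in one sentence."}],
    max_new_tokens=64,
)
print(reply)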

View File

@ -0,0 +1,8 @@
from __future__ import annotations
from typing import NamedTuple, Dict, Any
class ProcArgsRes(NamedTuple):
seed: int
generate_kwargs: Dict[str, Any]

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import warnings
from typing import Optional, Any
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable
class TransformersManagedModel(ModelManageable):
def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
self.repo_id = repo_id
self.model = model
self.tokenizer = tokenizer
self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
self.load_device = get_torch_device()
self.offload_device = unet_offload_device()
if model.device != self.offload_device:
model.to(device=self.offload_device)
load_device: torch.device
offload_device: torch.device
model: PreTrainedModel
@property
def current_device(self) -> torch.device:
return self.model.device
def is_clone(self, other: Any) -> bool:
return hasattr(other, "model") and self.model is other.model
def clone_has_same_weights(self, clone: Any) -> bool:
if not isinstance(clone, TransformersManagedModel):
return False
clone: TransformersManagedModel
if not self.is_clone(clone):
return False
return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
def model_size(self) -> int:
return self._size
def model_patches_to(self, arg: torch.device | torch.dtype):
if isinstance(arg, torch.device):
self.model.to(device=arg)
else:
self.model.to(arg)
def model_dtype(self) -> torch.dtype:
return self.model.dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=device_to)
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=device_to)
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=offload_device)
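
A rough sketch of how this wrapper can be handed to ComfyUI's model management (the import path follows the `comfy` package used elsewhere in this diff and is an assumption, as is the repo id):

from transformers import AutoModelForCausalLM, AutoTokenizer
from comfy.model_management import load_models_gpu  # assumed import path

repo_id = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

managed = TransformersManagedModel(repo_id, model, tokenizer)  # lands on the offload device
load_models_gpu([managed])                                     # moved to the load device on demand
print(managed.current_device, managed.model_size())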

View File

@ -7,7 +7,7 @@ from os.path import join
from typing import List, Any, Optional, Union
import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, scan_cache_dir
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
@ -167,6 +167,7 @@ KNOWN_CHECKPOINTS = [
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
]
KNOWN_UNCLIP_CHECKPOINTS = [
@ -297,6 +298,12 @@ KNOWN_VAES = [
HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
]
KNOWN_HUGGINGFACE_MODEL_REPOS = {
"JingyeChen22/textdiffuser2_layout_planner",
"JingyeChen22/textdiffuser2-full-ft",
"microsoft/Phi-3-mini-4k-instruct",
}
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
if args.disable_known_models:
@ -304,3 +311,10 @@ def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile
symbol += models
folder_paths.invalidate_cache(folder_name)
return symbol
def huggingface_repos() -> List[str]:
cache_info = scan_cache_dir()
existing_repo_ids = frozenset(cache_item.repo_id for cache_item in cache_info.repos if cache_item.repo_type == "model")
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
return list(existing_repo_ids | known_repo_ids)
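
The dropdown entries are the union of repos already in the local Hugging Face cache and the known set above; a hypothetical check (the repo id is only an example):

from huggingface_hub import snapshot_download

snapshot_download("stabilityai/sd-vae-ft-mse")             # cache a repo locally
print("stabilityai/sd-vae-ft-mse" in huggingface_repos())  # True once it is cached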

View File

@ -1,34 +1,36 @@
from __future__ import annotations
import logging
import sys
from enum import Enum
from threading import RLock
from typing import Literal
import psutil
import logging
from enum import Enum
from .cli_args import args
from . import interruption
from threading import RLock
import torch
import sys
from . import interruption
from .cli_args import args
from .model_management_types import ModelManageable
model_management_lock = RLock()
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
DISABLED = 0 # No vram present: no need to move models to vram
NO_VRAM = 1 # Very low vram: enable all the options to save vram
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
SHARED = 5 # No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
@ -46,6 +48,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
import torch_directml
directml_enabled = True
device_index = args.directml
if device_index < 0:
@ -54,10 +57,11 @@ if args.directml is not None:
directml_device = torch_directml.device(device_index)
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
except:
@ -73,6 +77,7 @@ except:
if args.cpu:
cpu_state = CPUState.CPU
def is_intel_xpu():
global cpu_state
global xpu_available
@ -81,6 +86,7 @@ def is_intel_xpu():
return True
return False
def get_torch_device():
global directml_enabled
global cpu_state
@ -97,6 +103,7 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@ -107,7 +114,7 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total_torch = mem_total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
@ -126,6 +133,7 @@ def get_total_memory(dev=None, torch_total_too=False):
else:
return mem_total
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@ -147,6 +155,7 @@ else:
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
try:
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
@ -164,6 +173,7 @@ else:
except:
XFORMERS_IS_AVAILABLE = False
def is_nvidia():
global cpu_state
if cpu_state == CPUState.GPU:
@ -171,6 +181,7 @@ def is_nvidia():
return True
return False
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
@ -205,7 +216,6 @@ elif args.bf16_vae:
elif args.fp32_vae:
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
@ -233,7 +243,6 @@ if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
if cpu_state != CPUState.GPU:
vram_state = VRAMState.DISABLED
@ -247,6 +256,7 @@ DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY:
logging.info("Disabling smart memory management")
def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
@ -262,6 +272,7 @@ def get_torch_device_name(device):
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
@ -271,6 +282,7 @@ logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
@ -279,6 +291,7 @@ def module_size(module):
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel:
def __init__(self, model: ModelManageable):
self.model = model
@ -328,9 +341,11 @@ class LoadedModel:
def __eq__(self, other):
return self.model is other.model
def minimum_inference_memory():
return (1024 * 1024 * 1024)
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
to_unload = []
@ -361,12 +376,13 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> b
return unload_weight
def free_memory(memory_required, device, keep_loaded=[]):
with model_management_lock:
unloaded_model = []
can_unload = []
for i in range(len(current_loaded_models) -1, -1, -1):
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
@ -391,6 +407,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
global vram_state
@ -424,7 +441,7 @@ def load_models_gpu(models, memory_required=0):
total_memory_required = {}
for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False): # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
@ -432,7 +449,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
@ -447,8 +464,8 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
if model_size > (current_free_mem - inference_memory): # only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
@ -465,6 +482,7 @@ def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
@ -472,8 +490,8 @@ def cleanup_models(keep_clone_weights_loaded=False):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
# TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: # references from .real_model + the .model
to_delete = [i] + to_delete
for i in to_delete:
@ -481,6 +499,7 @@ def cleanup_models(keep_clone_weights_loaded=False):
x.model_unload()
del x
def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
@ -490,17 +509,19 @@ def dtype_size(dtype):
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
except: # Old pytorch doesn't have .itemsize
pass
return dtype_size
def unet_offload_device():
if vram_state == VRAMState.HIGH_VRAM:
return get_torch_device()
else:
return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
def unet_initial_load_device(parameters, dtype):
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM:
return torch_dev
@ -518,7 +539,8 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@ -535,8 +557,9 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.bfloat16
return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if weight_dtype == torch.float32:
return None
@ -556,12 +579,14 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
else:
return torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.device("cpu")
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
@ -573,6 +598,7 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
@ -595,27 +621,32 @@ def intermediate_device():
else:
return torch.device("cpu")
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device()
def vae_offload_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.device("cpu")
def vae_dtype():
global VAE_DTYPE
return VAE_DTYPE
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
return "cuda"
def supports_dtype(device, dtype): #TODO
def supports_dtype(device, dtype): # TODO
if dtype == torch.float32:
return True
if is_device_cpu(device):
@ -626,12 +657,14 @@ def supports_dtype(device, dtype): #TODO
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return False # pytorch bug? mps doesn't support non blocking
return False
# return True #TODO: figure out why this causes issues
def cast_to_device(tensor, device, dtype, copy=False):
with model_management_lock:
device_supports_cast = False
@ -655,6 +688,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
global cpu_state
@ -674,18 +708,21 @@ def xformers_enabled_vae():
return XFORMERS_ENABLED_VAE
def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
#TODO: more reliable way of checking for flash attention?
if is_nvidia(): #pytorch flash attention only works on Nvidia
# TODO: more reliable way of checking for flash attention?
if is_nvidia(): # pytorch flash attention only works on Nvidia
return True
return False
def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
if dev is None:
@ -696,7 +733,7 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_torch = mem_free_total
else:
if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_total = 1024 * 1024 * 1024 # TODO
mem_free_torch = mem_free_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
@ -718,29 +755,36 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU
def mps_mode():
global cpu_state
return cpu_state == CPUState.MPS
def is_device_type(device, type):
if hasattr(device, 'type'):
if (device.type == type):
return True
return False
def is_device_cpu(device):
return is_device_type(device, 'cpu')
def is_device_mps(device):
return is_device_type(device, 'mps')
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
@ -781,9 +825,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
fp16_works = False
#FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
#when the model doesn't actually fit on the card
#TODO: actually test if GP106 and others have the same type of behavior
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
# when the model doesn't actually fit on the card
# TODO: actually test if GP106 and others have the same type of behavior
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
@ -797,7 +841,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if props.major < 7:
return False
#FP16 is just broken on these cards
# FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
for x in nvidia_16_series:
if x in props.name:
@ -805,12 +849,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None: #TODO not sure about mps bf16 support
if device is not None: # TODO not sure about mps bf16 support
if is_device_mps(device):
return False
@ -842,6 +887,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def soft_empty_cache(force=False):
with model_management_lock:
global cpu_state
@ -850,16 +896,17 @@ def soft_empty_cache(force=False):
elif is_intel_xpu():
torch.xpu.empty_cache()
elif torch.cuda.is_available():
if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def unload_all_models():
with model_management_lock:
free_memory(1e30, get_torch_device())
def resolve_lowvram_weight(weight, model, key): #TODO: remove
def resolve_lowvram_weight(weight, model, key): # TODO: remove
return weight

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Protocol, Optional
from typing import Protocol, Optional, Any
import torch
@ -18,13 +18,12 @@ class ModelManageable(Protocol):
load_device: torch.device
offload_device: torch.device
model: torch.nn.Module
current_device: torch.device
@property
def dtype(self) -> torch.dtype:
def current_device(self) -> torch.device:
...
def is_clone(self, other: torch.nn.Module) -> bool:
def is_clone(self, other: Any) -> bool:
...
def clone_has_same_weights(self, clone: torch.nn.Module) -> bool:

View File

@ -1,11 +1,12 @@
import torch
import copy
import inspect
import logging
import uuid
from . import utils
import torch
from . import model_management
from . import utils
from .model_management_types import ModelManageable
@ -20,6 +21,7 @@ def apply_weight_decompose(dora_scale, weight):
return weight * (dora_scale / weight_norm)
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -41,6 +43,7 @@ def set_model_options_patch_replace(model_options, patch, name, block_name, numb
model_options["transformer_options"] = to
return model_options
class ModelPatcher(ModelManageable):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@ -49,14 +52,15 @@ class ModelPatcher(ModelManageable):
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_options = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self._current_device: torch.device
if current_device is None:
self.current_device = self.offload_device
self._current_device = self.offload_device
else:
self.current_device = current_device
self._current_device = current_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
@ -71,7 +75,7 @@ class ModelPatcher(ModelManageable):
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self._current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -107,7 +111,7 @@ class ModelPatcher(ModelManageable):
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) # Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
@ -270,18 +274,20 @@ class ModelPatcher(ModelManageable):
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
self._current_device = device_to
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
self.patch_model(device_to, patch_weights=False)
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
@ -325,7 +331,7 @@ class ModelPatcher(ModelManageable):
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
if len(v) == 1:
patch_type = "diff"
@ -340,14 +346,14 @@ class ModelPatcher(ModelManageable):
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
elif patch_type == "lora": # lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
@ -407,7 +413,7 @@ class ModelPatcher(ModelManageable):
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
if v[5] is not None: # cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
@ -478,10 +484,14 @@ class ModelPatcher(ModelManageable):
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
self._current_device = value = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear()
@property
def current_device(self) -> torch.device:
return self._current_device

View File

@ -9,6 +9,7 @@ import logging
from PIL import Image, ImageOps, ImageSequence, ImageFile
from PIL.PngImagePlugin import PngInfo
from huggingface_hub import hf_hub_download, snapshot_download
from natsort import natsorted
import numpy as np
import safetensors.torch
@ -25,11 +26,13 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
from .. import node_helpers
from ..utils import comfy_tqdm
class CLIPTextEncode:
@classmethod
@ -513,11 +516,14 @@ class DiffusersLoader:
if "model_index.json" in files:
paths.append(os.path.relpath(root, start=search_path))
paths += huggingface_repos()
paths = list(frozenset(paths))
return {"required": {"model_path": (paths,), }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders/deprecated"
CATEGORY = "advanced/loaders"
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
@ -526,6 +532,9 @@ class DiffusersLoader:
if os.path.exists(path):
model_path = path
break
if not os.path.exists(model_path):
with comfy_tqdm():
model_path = snapshot_download(model_path)
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
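
With this change, a model_path that is a Hugging Face repo id rather than a local folder is fetched with snapshot_download on first use; a hypothetical invocation:

loader = DiffusersLoader()
model, clip, vae = loader.load_checkpoint("JingyeChen22/textdiffuser2-full-ft")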
@ -1029,7 +1038,7 @@ class LatentFromBatch:
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
@ -1044,7 +1053,7 @@ class RepeatLatentBatch:
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
@ -1374,7 +1383,7 @@ class SaveImage:
@classmethod
def INPUT_TYPES(s):
return {"required":
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
@ -1450,9 +1459,9 @@ class LoadImage:
def load_image(self, image: str):
image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []

View File

@ -1,31 +1,32 @@
import torch
from enum import Enum
from __future__ import annotations
import dataclasses
import logging
from enum import Enum
from typing import Any, Optional
from . import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
import torch
import yaml
from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import gligen
from . import lora
from . import model_detection
from . import model_management
from . import model_patcher
from . import model_sampling
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import model_patcher
from . import model_sampling
from . import lora
from . import utils
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .t2i_adapter import adapter
from .taesd import taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
m = set(m)
@ -40,6 +41,7 @@ def load_model_weights(model, sd):
logging.warning("missing {}".format(m))
return model
def load_clip_weights(model, sd):
k = list(sd.keys())
for x in k:
@ -87,7 +89,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False):
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None):
if no_init:
return
params = target.params.copy()
@ -98,10 +100,12 @@ class CLIP:
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
params['dtype'] = model_management.text_encoder_dtype(load_device)
if "textmodel_json_config" not in params and textmodel_json_config is not None:
params['textmodel_json_config'] = textmodel_json_config
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
@ -157,12 +161,13 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) # These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.upscale_ratio = 8
@ -181,16 +186,16 @@ class VAE:
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = taesd.TAESD()
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
elif "vquantizer.codebook.weight" in sd: # VQGan: stage a of stable cascade
self.first_stage_model = StageA()
self.downscale_ratio = 4
self.upscale_ratio = 4
#TODO
#self.memory_used_encode
#self.memory_used_decode
# TODO
# self.memory_used_encode
# self.memory_used_decode
self.process_input = lambda image: image
self.process_output = lambda image: image
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: # effnet: encoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
@ -198,22 +203,22 @@ class VAE:
for k in sd:
new_sd["encoder.{}".format(k)] = sd[k]
sd = new_sd
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
elif "blocks.11.num_batches_tracked" in sd: # previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
new_sd = {}
for k in sd:
new_sd["previewer.{}".format(k)] = sd[k]
sd = new_sd
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: # combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
# default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: # Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.upscale_ratio = 4
@ -261,7 +266,7 @@ class VAE:
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
@ -269,22 +274,22 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
/ 3.0)
return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
@ -298,23 +303,23 @@ class VAE:
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x + batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
return output.movedim(1, -1)
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
pixel_samples = pixel_samples.movedim(-1, 1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@ -323,8 +328,8 @@ class VAE:
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -332,16 +337,17 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
pixel_samples = pixel_samples.movedim(-1, 1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
return samples
def get_sd(self):
return self.first_stage_model.state_dict()
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@ -360,26 +366,33 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data)
return StyleModel(model)
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
@dataclasses.dataclass
class CLIPTarget:
clip: Optional[Any] = None
vae: Optional[Any] = None
params: Optional[dict] = dataclasses.field(default_factory=dict)
tokenizer: Optional[Any] = None
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None):
clip_data = []
for p in ckpt_paths:
clip_data.append(utils.load_torch_file(p, safe_load=True))
class EmptyClass:
pass
for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = utils.clip_text_transformers_convert(clip_data[i], "", "")
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) # old models saved with the CLIPSave node
clip_target = EmptyClass()
clip_target = CLIPTarget()
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
@ -399,7 +412,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@ -409,6 +422,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
logging.debug("clip unexpected: {}".format(u))
return clip
def load_gligen(ckpt_path):
data = utils.load_torch_file(ckpt_path, safe_load=True)
model = gligen.load_gligen(data)
@ -416,10 +430,11 @@ def load_gligen(ckpt_path):
model = model.half()
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
logging.warning("Warning: The load checkpoint with config function is deprecated and will eventually be removed, please use the other one.")
model, clip, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=False, embedding_directory=embedding_directory, output_model=True)
#TODO: this function is a mess and should be removed eventually
# TODO: this function is a mess and should be removed eventually
if config is None:
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
@ -430,8 +445,10 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
m = model.clone()
class ModelSamplingAdvanced(model_sampling.ModelSamplingDiscrete, model_sampling.V_PREDICTION):
pass
m.add_object_patch("model_sampling", ModelSamplingAdvanced(model.model.model_config))
model = m
@ -441,6 +458,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
@ -467,7 +485,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
inital_load_device = model_management.unet_initial_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
@ -509,18 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (_model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd): #load unet in diffusers format
def load_unet_state_dict(sd): # load unet in diffusers format
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: # ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
else: # diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@ -546,6 +564,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
logging.info("left over keys in unet: {}".format(left_over))
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
@ -554,6 +573,7 @@ def load_unet(unet_path):
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
clip_sd = None
load_models = [model]

View File

@ -1,15 +1,45 @@
import os
from __future__ import annotations
from transformers import CLIPTokenizer
from . import ops
import torch
import traceback
import zipfile
from . import model_management
from pkg_resources import resource_filename
from . import clip_model
import importlib.resources as resources
import json
import logging
import os
import traceback
import zipfile
from typing import List
import torch
from pkg_resources import resource_filename
from transformers import CLIPTokenizer
from . import clip_model
from . import model_management
from . import ops
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
config: dict | None = None
if text_model_config_or_path is None:
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
if isinstance(text_model_config_or_path, str):
if text_model_config_or_path.startswith("{"):
config = json.loads(text_model_config_or_path)
else:
if not os.path.exists(text_model_config_or_path):
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(text_model_config_or_path) as f:
config = json.load(f)
elif isinstance(text_model_config_or_path, dict):
config = text_model_config_or_path
assert config is not None
return config
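
The helper accepts the packaged default, a path on disk, an inline JSON string, or an already-parsed dict; hypothetical calls (values are illustrative):

cfg_default = get_clip_config_dict(None, "sd1_clip_config.json")                    # packaged default
cfg_from_path = get_clip_config_dict("models/clip/config.json", "sd1_clip_config.json")
cfg_inline = get_clip_config_dict('{"hidden_size": 768}', "sd1_clip_config.json")   # inline JSON string
cfg_as_dict = get_clip_config_dict({"hidden_size": 768}, "sd1_clip_config.json")    # passed through as-is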
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@ -23,6 +53,7 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
@ -46,7 +77,7 @@ class ClipTokenWeightEncoder:
output = []
for k in range(0, sections):
z = out[k:k+1]
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
@ -60,6 +91,7 @@ class ClipTokenWeightEncoder:
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
@ -67,20 +99,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True): # clip-vit-base-patch32
super().__init__()
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
with open(textmodel_json_config) as f:
config = json.load(f)
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers
@ -105,7 +133,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
@ -132,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_temp = []
for y in x:
if isinstance(y, int):
if y == token_dict_size: #EOS token
if y == token_dict_size: # EOS token
y = -1
tokens_temp += [y]
else:
@ -153,12 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
new_embedding.weight[n] = current_embeds.weight[-1] # EOS embedding
self.transformer.set_input_embeddings(new_embedding)
processed_tokens = []
for x in out_tokens:
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one
return processed_tokens
@ -201,6 +229,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
result = []
current_item = ""
@ -229,6 +258,7 @@ def parse_parentheses(string):
result.append(current_item)
return result
def token_weights(string, current_weight):
a = parse_parentheses(string)
out = []
@ -240,7 +270,7 @@ def token_weights(string, current_weight):
weight *= 1.1
if xx > 0:
try:
weight = float(x[xx+1:])
weight = float(x[xx + 1:])
x = x[:xx]
except:
pass
@ -249,16 +279,19 @@ def token_weights(string, current_weight):
out += [(x, current_weight)]
return out
def escape_important(text):
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text
def unescape_important(text):
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
def safe_load_embed_zip(embed_path):
with zipfile.ZipFile(embed_path) as myzip:
names = list(filter(lambda a: "data/" in a, myzip.namelist()))
@ -267,17 +300,18 @@ def safe_load_embed_zip(embed_path):
with myzip.open(n) as myfile:
data = myfile.read()
number = len(data) // 4
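# the stored values are float32 (4 bytes each); the element count is used below to infer the embedding width (768 for SD1.x, 1024 for SD2.x)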
length_embed = 1024 #sd2.x
length_embed = 1024 # sd2.x
if number < 768:
continue
if number % 768 == 0:
length_embed = 768 #sd1.x
length_embed = 768 # sd1.x
num_embeds = number // length_embed
embed = torch.frombuffer(data, dtype=torch.float)
out = embed.reshape((num_embeds, length_embed)).clone()
del embed
return out
def expand_directory_list(directories):
dirs = set()
for x in directories:
@ -286,6 +320,7 @@ def expand_directory_list(directories):
dirs.add(root)
return list(dirs)
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
@ -356,6 +391,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values))
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
if tokenizer_path is None:
@ -378,16 +414,20 @@ class SDTokenizer:
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.add_tokens([])
self.embedding_directory = embedding_directory
self.max_word_length = 8
self.embedding_identifier = "embedding:"
self.embedding_size = embedding_size
self.embedding_key = embedding_key
def _try_get_embedding(self, embedding_name:str):
def add_tokens(self, tokens: List[str]):
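# register any new tokens with the wrapped HF tokenizer and rebuild the cached id -> token mapping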
if len(tokens) > 0:
self.tokenizer.add_tokens(tokens)
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
def _try_get_embedding(self, embedding_name: str):
'''
Takes a potential embedding name and tries to retrieve it.
Returns a tuple of (embedding, leftover string); the embedding can be None.
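For example, a hypothetical name with a trailing comma such as "myembed," resolves the stripped name and returns (embedding_tensor, ",").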
@ -400,8 +440,7 @@ class SDTokenizer:
return (embed, embedding_name[len(stripped):])
return (embed, "")
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can be either integer ids or precomputed CLIP embedding tensors.
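Illustrative example (ids are hypothetical and depend on the loaded vocab):
tokenize_with_weights("a (red:1.2) ball", return_word_ids=True)
-> [[(start, 1.0, 0), (id_a, 1.0, 1), (id_red, 1.2, 2), (id_ball, 1.0, 3), (end, 1.0, 0), ...padding]]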
@ -417,13 +456,13 @@ class SDTokenizer:
parsed_weights = token_weights(text, 1.0)
vocab = self.tokenizer.get_vocab()
#tokenize words
# tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
# if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
@ -434,52 +473,54 @@ class SDTokenizer:
tokens.append([(embed, weight)])
else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
# if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "":
word = leftover
else:
continue
#parse word
# parse word
exact_word = f"{word}</w>"
if exact_word in vocab:
if word == self.tokenizer.eos_token:
tokenizer_result = [self.tokenizer.eos_token_id]
elif exact_word in vocab:
tokenizer_result = [vocab[exact_word]]
else:
tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:-1]
tokens.append([(t, weight) for t in tokenizer_result])
#reshape token array to CLIP input size
# reshape token array to CLIP input size
batched_tokens = []
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
# determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token
# break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
# add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
# start new batch
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
batch.extend([(t, w, i + 1) for t, w in t_group])
t_group = []
#fill last batch
# fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
@ -487,11 +528,10 @@ class SDTokenizer:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
@ -502,21 +542,25 @@ class SD1Tokenizer:
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
out[self.clip_name] = self.sd_tokenizer.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)
return self.sd_tokenizer.untokenize(token_weight_pair)
@property
def sd_tokenizer(self) -> SDTokenizer:
return getattr(self, self.clip)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
super().__init__()
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options)

View File

@ -1,27 +1,28 @@
from pkg_resources import resource_filename
from . import sd1_clip
import os
from .sd1_clip import get_clip_config_dict
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
layer = "hidden"
layer_idx = -2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, textmodel_json_config=textmodel_json_config, **kwargs)

View File

@ -1,20 +1,23 @@
from . import sd1_clip
import torch
import os
from . import sd1_clip
from .sd1_clip import get_clip_config_dict
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
layer = "hidden"
layer_idx = -2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
@ -25,7 +28,7 @@ class SDXLTokenizer:
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
@ -34,6 +37,7 @@ class SDXLTokenizer:
def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
@ -61,28 +65,32 @@ class SDXLClipModel(torch.nn.Module):
else:
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, textmodel_json_config=textmodel_json_config)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config)

View File

@ -1,11 +1,15 @@
from __future__ import annotations
import contextlib
import logging
import math
import os.path
import random
import struct
import sys
import warnings
from contextlib import contextmanager
from typing import Optional
from typing import Optional, Any
import numpy as np
import safetensors.torch
@ -14,6 +18,7 @@ from PIL import Image
from tqdm import tqdm
from . import checkpoint_pickle, interruption
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
@ -30,6 +35,7 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
@ -498,8 +504,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
return output
def _progress_bar_update(value: float, total: float, preview_image, client_id: Optional[str] = None):
server = current_execution_context().server
def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
server = server or current_execution_context().server
# todo: this should really be from the context. right now the server is behaving like a context
client_id = client_id or server.client_id
interruption.throw_exception_if_processing_interrupted()
@ -570,15 +576,14 @@ def comfy_tqdm():
"""
_original_init = tqdm.__init__
_original_update = tqdm.update
server = current_execution_context().server
try:
def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs)
self._progress_bar = ProgressBar(self.total)
def __update(self, n=1):
assert self._progress_bar is not None
_original_update(self, n)
self._progress_bar.update(n)
_progress_bar_update(n, self.total, server=server)
tqdm.__init__ = __init
tqdm.update = __update
@ -596,3 +601,30 @@ def comfy_progress(total: float) -> ProgressBar:
yield ProgressBar(total)
else:
yield _DisabledProgressBar()
@contextlib.contextmanager
def seed_for_block(seed):
# Save the current random state
torch_rng_state = torch.get_rng_state()
random_state = random.getstate()
numpy_rng_state = np.random.get_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all()
# Set the new seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
yield
finally:
# Restore the previous random state
torch.set_rng_state(torch_rng_state)
random.setstate(random_state)
np.random.set_state(numpy_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state_all(cuda_rng_state)
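# Example usage (illustrative): draw reproducible noise without disturbing the
# caller's torch / python / numpy RNG streams:
#
#     with seed_for_block(42):
#         noise = torch.randn((1, 4, 64, 64))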

View File

@ -0,0 +1,59 @@
/**
* Uses code adapted from https://github.com/Zuellni/ComfyUI-ExLlama-Nodes
*
* MIT License
*
* Copyright (c) 2023 Zuellni
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
import { app } from "../../scripts/app.js";
import { ComfyWidgets } from "../../scripts/widgets.js";
app.registerExtension({
name: "Comfy.StringNodes",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "PreviewString" || nodeData.name === "SaveString") {
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function ({ string }) {
onExecuted?.apply(this, arguments);
if (this.widgets) {
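// drop any output widgets left over from a previous execution so results do not stack up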
const index = this.widgets.findIndex((w) => w.name === "output");
if (index !== -1) {
for (let i = index; i < this.widgets.length; i++) {
this.widgets[i].onRemove?.();
}
this.widgets.length = index;
}
const options = ["STRING", { multiline: true }];
const widget = ComfyWidgets["STRING"](this, "output", options, app).widget;
widget.inputEl.readOnly = true;
widget.inputEl.style.opacity = 0.7;
widget.value = string;
}
};
}
},
});

View File

@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Any, List, Dict
import torch
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from comfy.language.language_types import ProcArgsRes
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_tqdm, seed_for_block
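# widget definitions shared by the decode nodes below; the keys map directly onto kwargs accepted by transformers' generate()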
_transformer_args_deterministic_decoding = {
"max_length": ("INT", {"default": 4096, "min": 1}),
"temperature": ("FLOAT", {"default": 0.7, "min": 0}),
"repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
}
def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
# "seed" is not a generate() kwarg, so pull it out of the raw node kwargs before filtering
seed = kwargs.pop("seed", 0)
generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
return ProcArgsRes(seed, generate_kwargs)
class TransformersLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),)
}
}
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str):
with comfy_tqdm():
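# load weights in the managed inference dtype onto the offload device; trust_remote_code allows repos that ship custom modeling code (such as Phi-3)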
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
return model_managed,
class SimpleBatchDecode(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
tokenizer = model.tokenizer
inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
with comfy_tqdm():
with seed_for_block(seed):
generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return outputs,
class SimpleInstruct(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
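# fastchat picks a conversation template from the repo name and falls back to a generic one-shot template for unknown models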
conv = get_conversation_template(model.repo_id)
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = model.tokenizer([prompt], return_token_type_ids=False)
inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
with seed_for_block(seed):
output_ids = model.model.generate(
**inputs,
do_sample=True,
**generate_kwargs
)
if model.model.config.is_encoder_decoder:
output_ids = output_ids[0]
else:
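# decoder-only models echo the prompt tokens in the generate() output, so strip them before decoding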
output_ids = output_ids[0][len(inputs["input_ids"][0]):]
outputs = model.tokenizer.decode(
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
)
return outputs,
class PreviewString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
}
}
FUNCTION = "execute"
RETURN_TYPES = ("STRING",)
OUTPUT_NODE = True
def execute(self, value: str):
return {"ui": {"string": [value]}}
NODE_CLASS_MAPPINGS = {}
for cls in (
TransformersLoader,
SimpleBatchDecode,
SimpleInstruct,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -0,0 +1,142 @@
"""
Adapted from https://github.com/microsoft/unilm/blob/master/textdiffuser-2/inference_textdiffuser2_t2i_full.py#L334
The MIT License (MIT)
Copyright (c) Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import string
from typing import Optional
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.sd import CLIP
from comfy.sd1_clip import SDTokenizer
class TextDiffuserTokens(CustomNode):
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' ' # len(alphabet) = 95
TOKENS = []
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"clip": ("CLIP",)
}
}
RETURN_TYPES = ("CLIP",)
FUNCTION = "execute"
def execute(self, clip: CLIP):
if len(TextDiffuserTokens.TOKENS) == 0:
for i in range(520):
TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
TextDiffuserTokens.TOKENS.append(f't{i}</w>')
TextDiffuserTokens.TOKENS.append(f'r{i}</w>')
TextDiffuserTokens.TOKENS.append(f'b{i}</w>')
for c in TextDiffuserTokens.ALPHABET:
TextDiffuserTokens.TOKENS.append(f'[{c}]</w>')
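# TOKENS now holds four coordinate tokens (l/t/r/b) for each position 0..519 plus one bracketed token per printable character, matching the TextDiffuser-2 layout vocabulary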
tokenizer: SDTokenizer = clip.tokenizer.sd_tokenizer
existing_vocab = frozenset(tokenizer.tokenizer.get_vocab().keys())
tokens = [t for t in TextDiffuserTokens.TOKENS if t not in existing_vocab]
if len(tokens) != 0:
tokenizer.add_tokens(tokens)
# todo: assert that the clip's vocab size is what we expect
return clip,
class TextDiffuserPrepare(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"prompt": ("STRING", {"default": "", "multiline": True}),
},
"optional": {
"text": ("STRING", {"default": "", "multiline": True})
}
}
FUNCTION = "execute"
RETURN_TYPES = "STRING",
RETURN_NAMES = "INSTRUCT STRING",
def execute(self, prompt: str, text: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
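# builds the layout-planning instruction for the language model; coordinates are expressed on a 128x128 grid per the TextDiffuser-2 recipe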
keywords = [k for k in text.split("\n") if k.strip() != ""] if text else []
if len(keywords) > 0:
# the TextDiffuser-2 prompt expects the keyword list rendered as a Python list literal,
# e.g. ['some', 'word']
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}'
else:
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
return message,
class TextDiffuserDecodeLayout(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"layout_model": ("MODEL", {}),
"clip": ("CLIP", {}),
"prompt": ("STRING", {}),
"instruct_response": ("STRING", {})
}
}
FUNCTION = "execute"
RETURN_TYPES = "STRING",
RETURN_NAMES = "CLIP STRING",
def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str, *args, **kwargs) -> ValidatedNodeResult:
current_ocr = instruct_response.split('\n')
words = [clip.tokenizer.sd_tokenizer.tokenizer.eos_token, clip.tokenizer.sd_tokenizer.tokenizer.bos_token]
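# each non-empty line from the layout model is expected to look like "hello 12,8,86,29" (keyword followed by left,top,right,bottom coordinates); the values here are only illustrative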
for ocr in current_ocr:
ocr = ocr.strip()
# skip blank lines and lines containing junk markers ('###', '.com') in the layout model output
if len(ocr) == 0 or '###' in ocr or '.com' in ocr:
continue
items = ocr.split()
pred = ' '.join(items[:-1])
box = items[-1]
l, t, r, b = map(int, box.split(','))
words.extend([f'l{l}', f't{t}', f'r{r}', f'b{b}'])
char_list = [f'[{i}]' for i in pred]
words.extend(char_list)
words.append(clip.tokenizer.sd_tokenizer.tokenizer.eos_token)
return prompt + ' ' + ' '.join(words),
NODE_CLASS_MAPPINGS = {}
for cls in (
TextDiffuserDecodeLayout,
TextDiffuserPrepare,
TextDiffuserTokens,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -5,7 +5,11 @@ torchsde>=0.2.6
einops>=0.6.0
open-clip-torch>=2.16.0
transformers>=4.29.1
peft
torchinfo
fschat[model_worker]
safetensors>=0.3.0
bitsandbytes
pytorch-lightning>=2.0.0
aiohttp>=3.8.4
accelerate>=0.25.0