Merge upstream

doctorpangloss 2024-03-11 09:32:57 -07:00
commit 00728eb20f
23 changed files with 161 additions and 77 deletions

View File

@@ -39,7 +39,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
     # patch nodes
     from ..nodes.base_nodes import SaveImage, CLIPTextEncode, LoraLoader, CheckpointLoaderSimple
     from ..cmd.execution import PromptQueue
-    from comfy.component_model.queue_types import QueueItem
+    from ..component_model.queue_types import QueueItem
     prompt_queue_put = PromptQueue.put

View File

@@ -176,6 +176,7 @@ def create_parser() -> argparse.ArgumentParser:
                         help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
     parser.add_argument("--external-address", required=False,
                         help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
+    parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
@@ -208,6 +209,12 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
     if args.disable_auto_launch:
         args.auto_launch = False
+    logging_level = logging.WARNING
+    if args.verbose:
+        logging_level = logging.DEBUG
+    logging.basicConfig(format="%(message)s", level=logging_level)
     return Configuration(**vars(args))

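The two hunks above are what make the print-to-logging conversion in the rest of this merge configurable: a new --verbose flag drops the root logger threshold from WARNING to DEBUG. A minimal standalone sketch of the same pattern (argument and format string copied from the hunks; the entry-point wiring here is illustrative, not the project's actual startup code):

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
args = parser.parse_args(["--verbose"])  # simulate running with --verbose

# Default threshold is WARNING, so only the louder messages are shown;
# --verbose lowers it to DEBUG so logging.debug()/info() calls become visible.
logging_level = logging.WARNING
if args.verbose:
    logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)

logging.debug("printed only when --verbose is passed")
logging.warning("always printed")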
View File

@@ -1,7 +1,7 @@
 import copy
 from typing import TypeAlias, Union
-from comfy.api.components.schema.prompt import PromptDict, Prompt
+from ..api.components.schema.prompt import PromptDict, Prompt
 JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
 _BASE_PROMPT: JSON = {

View File

@@ -1,5 +1,5 @@
 import torch
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from .ldm.modules.attention import optimized_attention_for_device
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):

View File

@@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
 import os
 import torch
 import json
+import logging
 from . import ops
 from . import model_patcher
@@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     clip = ClipVisionModel(json_config)
     m, u = clip.load_sd(sd)
     if len(m) > 0:
-        print("missing clip vision:", m)
+        logging.warning("missing clip vision: {}".format(m))
     u = set(u)
     keys = list(sd.keys())
     for k in keys:

View File

@@ -53,8 +53,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extr
             if h[x] == "PROMPT":
                 input_data_all[x] = [prompt]
             if h[x] == "EXTRA_PNGINFO":
-                if "extra_pnginfo" in extra_data:
-                    input_data_all[x] = [extra_data['extra_pnginfo']]
+                input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
             if h[x] == "UNIQUE_ID":
                 input_data_all[x] = [unique_id]
     return input_data_all

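A behavioral note on the EXTRA_PNGINFO hunk above: the old membership check left the input unset when a prompt carried no extra_pnginfo, while dict.get always populates it, falling back to None. A standalone illustration in plain Python (not project code):

extra_data = {}  # no "extra_pnginfo" key present

old_style = {}
if "extra_pnginfo" in extra_data:
    old_style["x"] = [extra_data["extra_pnginfo"]]

new_style = {"x": [extra_data.get("extra_pnginfo", None)]}

print(old_style)  # {}            -> the input was never set
print(new_style)  # {'x': [None]} -> the input exists and holds None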
View File

@@ -3,7 +3,7 @@ from __future__ import annotations # for Python 3.7-3.9
 from typing_extensions import NotRequired, TypedDict
 from typing import Optional, Literal, Protocol, TypeAlias, Union
-from comfy.component_model.queue_types import BinaryEventTypes
+from .queue_types import BinaryEventTypes
 class ExecInfo(TypedDict):

View File

@@ -1,6 +1,7 @@
 import torch
 import math
 import os
+import logging
 from . import utils
 from . import model_management
@@ -368,7 +369,7 @@ def load_controlnet(ckpt_path, model=None):
         leftover_keys = controlnet_data.keys()
         if len(leftover_keys) > 0:
-            print("leftover keys:", leftover_keys)
+            logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     pth_key = 'control_model.zero_convs.0.0.weight'
@@ -383,7 +384,7 @@ def load_controlnet(ckpt_path, model=None):
     else:
         net = load_t2i_adapter(controlnet_data)
         if net is None:
-            print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
+            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
         return net
     if controlnet_config is None:
@@ -418,7 +419,7 @@ def load_controlnet(ckpt_path, model=None):
                             cd = controlnet_data[x]
                             cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
             else:
-                print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
         class WeightsLoader(torch.nn.Module):
             pass
@@ -427,7 +428,12 @@ def load_controlnet(ckpt_path, model=None):
         missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
     else:
         missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
-    print(missing, unexpected)
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+    if len(unexpected) > 0:
+        logging.info("unexpected controlnet keys: {}".format(unexpected))
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
@@ -537,9 +543,9 @@ def load_t2i_adapter(t2i_data):
     missing, unexpected = model_ad.load_state_dict(t2i_data)
     if len(missing) > 0:
-        print("t2i missing", missing)
+        logging.warning("t2i missing {}".format(missing))
     if len(unexpected) > 0:
-        print("t2i unexpected", unexpected)
+        logging.info("t2i unexpected {}".format(unexpected))
     return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)

View File

@@ -1,5 +1,6 @@
 import re
 import torch
+import logging
 # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
-                print(f"Reshaping {k} for SD format")
+                logging.info(f"Reshaping {k} for SD format")
                 new_state_dict[k] = reshape_weight_for_sd(v)
     return new_state_dict

View File

@@ -6,7 +6,7 @@ from typing import Optional
 from aio_pika import connect_robust
 from aio_pika.patterns import RPC
-from comfy.distributed.distributed_types import RpcRequest, RpcReply
+from .distributed_types import RpcRequest, RpcReply
 class DistributedPromptClient:

View File

@@ -5,7 +5,7 @@ from typing import Optional, OrderedDict, List, Dict
 import collections
 from itertools import islice
-from comfy.component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
+from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
 class History:

View File

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from .ldm.modules.attention import CrossAttention
 from inspect import isfunction
-from comfy.ops import manual_cast
+from .ops import manual_cast
 ops = manual_cast
 def exists(val):

View File

@@ -1,3 +1,4 @@
+import logging
 from . import utils
 LORA_CLIP_MAP = {
@@ -156,7 +157,7 @@ def load_lora(lora, to_load):
     for x in lora.keys():
         if x not in loaded_keys:
-            print("lora key not loaded", x)
+            logging.warning("lora key not loaded: {}".format(x))
     return patch_dict
 def model_lora_keys_clip(model, key_map={}):

View File

@@ -1,4 +1,5 @@
 import torch
+import logging
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
@@ -66,8 +67,8 @@ class BaseModel(torch.nn.Module):
         if self.adm_channels is None:
             self.adm_channels = 0
         self.inpaint_model = False
-        print("model_type", model_type.name)
-        print("adm", self.adm_channels)
+        logging.warning("model_type {}".format(model_type.name))
+        logging.info("adm {}".format(self.adm_channels))
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
@@ -168,7 +169,7 @@ class BaseModel(torch.nn.Module):
         c_concat = kwargs.get("noise_concat", None)
         if c_concat is not None:
-            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
+            out['c_concat'] = conds.CONDNoiseShape(data)
         return out
@@ -182,10 +183,10 @@ class BaseModel(torch.nn.Module):
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
-            print("unet missing:", m)
+            logging.warning("unet missing: {}".format(m))
         if len(u) > 0:
-            print("unet unexpected:", u)
+            logging.warning("unet unexpected: {}".format(u))
         del to_load
         return self

View File

@@ -1,5 +1,6 @@
 from . import supported_models
 from . import supported_models_base
+import logging
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
@@ -186,7 +187,7 @@ def model_config_from_unet_config(unet_config):
         if model_config.matches(unet_config):
             return model_config(unet_config)
-    print("no match", unet_config)
+    logging.error("no match {}".format(unet_config))
     return None
 def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):

View File

@@ -1,4 +1,5 @@
 import psutil
+import logging
 from enum import Enum
 from .cli_args import args
 from . import utils
@@ -33,7 +34,7 @@ lowvram_available = True
 xpu_available = False
 if args.deterministic:
-    print("Using deterministic algorithms for pytorch")
+    logging.warning("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 directml_enabled = False
@@ -45,7 +46,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    print("Using directml with device:", torch_directml.device_name(device_index))
+    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
@@ -121,10 +122,10 @@ def get_total_memory(dev=None, torch_total_too=False):
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
-        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
         set_vram_to = VRAMState.LOW_VRAM
 try:
@@ -147,12 +148,10 @@ else:
         pass
     try:
         XFORMERS_VERSION = xformers.version.__version__
-        print("xformers version:", XFORMERS_VERSION)
+        logging.warning("xformers version: {}".format(XFORMERS_VERSION))
         if XFORMERS_VERSION.startswith("0.0.18"):
-            print()
-            print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-            print("Please downgrade or upgrade xformers to a different version.")
-            print()
+            logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
+            logging.warning("Please downgrade or upgrade xformers to a different version.\n")
             XFORMERS_ENABLED_VAE = False
     except:
         pass
@@ -217,11 +216,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    print("Forcing FP32, if this improves things please report it.")
+    logging.warning("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 if args.force_fp16 or cpu_state == CPUState.MPS:
-    print("Forcing FP16.")
+    logging.warning("Forcing FP16.")
     FORCE_FP16 = True
 if lowvram_available:
@@ -235,12 +234,12 @@ if cpu_state != CPUState.GPU:
 if cpu_state == CPUState.MPS:
     vram_state = VRAMState.SHARED
-print(f"Set vram state to: {vram_state.name}")
+logging.warning(f"Set vram state to: {vram_state.name}")
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 if DISABLE_SMART_MEMORY:
-    print("Disabling smart memory management")
+    logging.warning("Disabling smart memory management")
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -258,11 +257,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 try:
-    print("Device:", get_torch_device_name(get_torch_device()))
+    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
-    print("Could not pick default device.")
+    logging.warning("Could not pick default device.")
-print("VAE dtype:", VAE_DTYPE)
+logging.warning("VAE dtype: {}".format(VAE_DTYPE))
 current_loaded_models = []
@@ -305,7 +304,7 @@ class LoadedModel:
             raise e
         if lowvram_model_memory > 0:
-            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+            logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "comfy_cast_weights"):
@@ -318,7 +317,7 @@ class LoadedModel:
                 elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("lowvram: loaded module regularly", m)
+                    logging.warning("lowvram: loaded module regularly {}".format(m))
            self.model_accelerated = True
@@ -353,7 +352,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
     for i in to_unload:
-        print("unload clone", i)
+        logging.warning("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -397,7 +396,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                logging.warning(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
     if len(models_to_load) == 0:
@@ -407,7 +406,7 @@ def load_models_gpu(models, memory_required=0):
                 free_memory(extra_mem, d, models_already_loaded)
         return
-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
     total_memory_required = {}
     for loaded_model in models_to_load:

View File

@@ -1,6 +1,7 @@
 import torch
 import copy
 import inspect
+import logging
 from . import utils
 from . import model_management
@@ -187,7 +188,7 @@ class ModelPatcher:
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
-                print("could not patch. key doesn't exist in model:", key)
+                logging.warning("could not patch. key doesn't exist in model: {}".format(key))
                 continue
             weight = model_sd[key]
@@ -236,7 +237,7 @@ class ModelPatcher:
                 w1 = v[0]
                 if alpha != 0.0:
                     if w1.shape != weight.shape:
-                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
                         weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
             elif patch_type == "lora": #lora/locon
@@ -252,7 +253,7 @@ class ModelPatcher:
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
-                    print("ERROR", key, e)
+                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
                 w1 = v[0]
                 w2 = v[1]
@@ -291,7 +292,7 @@ class ModelPatcher:
                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
-                    print("ERROR", key, e)
+                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
@@ -320,7 +321,7 @@ class ModelPatcher:
                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
-                    print("ERROR", key, e)
+                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]
@@ -330,9 +331,12 @@ class ModelPatcher:
                 b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                 b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
-                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                try:
+                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                except Exception as e:
+                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
-                print("patch type not recognized", patch_type, key)
+                logging.warning("patch type not recognized {} {}".format(patch_type, key))
         return weight

View File

@@ -1,5 +1,5 @@
 import torch
-from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+from .ldm.modules.diffusionmodules.util import make_beta_schedule
 import math
 class EPS:

View File

@@ -1,5 +1,6 @@
 import torch
 from enum import Enum
+import logging
 from . import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
@@ -37,7 +38,7 @@ def load_model_weights(model, sd):
         w = sd.pop(x)
         del w
     if len(m) > 0:
-        print("missing", m)
+        logging.warning("missing {}".format(m))
     return model
 def load_clip_weights(model, sd):
@@ -81,7 +82,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
     k1 = set(k1)
     for x in loaded:
         if (x not in k) and (x not in k1):
-            print("NOT LOADED", x)
+            logging.warning("NOT LOADED {}".format(x))
     return (new_modelpatcher, new_clip)
@@ -225,10 +226,10 @@ class VAE:
         m, u = self.first_stage_model.load_state_dict(sd, strict=False)
         if len(m) > 0:
-            print("Missing VAE keys", m)
+            logging.warning("Missing VAE keys {}".format(m))
         if len(u) > 0:
-            print("Leftover VAE keys", u)
+            logging.info("Leftover VAE keys {}".format(u))
         if device is None:
             device = model_management.vae_device()
@@ -291,7 +292,7 @@ class VAE:
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
@@ -317,7 +318,7 @@ class VAE:
                 samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             samples = self.encode_tiled_(pixel_samples)
         return samples
@@ -393,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
-            print("clip missing:", m)
+            logging.warning("clip missing: {}".format(m))
         if len(u) > 0:
-            print("clip unexpected:", u)
+            logging.info("clip unexpected: {}".format(u))
     return clip
 def load_gligen(ckpt_path):
@@ -534,21 +535,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
                 clip = CLIP(clip_target, embedding_directory=embedding_directory)
                 m, u = clip.load_sd(clip_sd, full_model=True)
                 if len(m) > 0:
-                    print("clip missing:", m)
+                    logging.warning("clip missing: {}".format(m))
                 if len(u) > 0:
-                    print("clip unexpected:", u)
+                    logging.info("clip unexpected {}:".format(u))
             else:
-                print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
+                logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
     left_over = sd.keys()
     if len(left_over) > 0:
-        print("left over keys:", left_over)
+        logging.info("left over keys: {}".format(left_over))
     if output_model:
         _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
-            print("loaded straight to GPU")
+            logging.warning("loaded straight to GPU")
             model_management.load_model_gpu(_model_patcher)
     return (_model_patcher, clip, vae, clipvision)
@@ -577,7 +578,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
             if k in sd:
                 new_sd[diffusers_keys[k]] = sd.pop(k)
             else:
-                print(diffusers_keys[k], k)
+                logging.warning("{} {}".format(diffusers_keys[k], k))
     offload_device = model_management.unet_offload_device()
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
@@ -588,14 +589,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
-        print("left over keys in unet:", left_over)
+        logging.warning("left over keys in unet: {}".format(left_over))
     return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 def load_unet(unet_path):
     sd = utils.load_torch_file(unet_path)
     model = load_unet_state_dict(sd)
     if model is None:
-        print("ERROR UNSUPPORTED UNET", unet_path)
+        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model

View File

@@ -9,6 +9,7 @@ from . import model_management
 from pkg_resources import resource_filename
 from . import clip_model
 import json
+import logging
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -140,7 +141,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                         tokens_temp += [next_new_token]
                         next_new_token += 1
                     else:
-                        print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
+                        logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
             while len(tokens_temp) < len(x):
                 tokens_temp += [self.special_tokens["pad"]]
             out_tokens += [tokens_temp]
@@ -332,9 +333,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
         else:
             embed = torch.load(embed_path, map_location="cpu")
     except Exception as e:
-        print(traceback.format_exc())
-        print()
-        print("error loading embedding, skipping loading:", embedding_name)
+        logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
         return None
     if embed_out is None:
@@ -429,7 +428,7 @@ class SDTokenizer:
                     embedding_name = word[len(self.embedding_identifier):].strip('\n')
                     embed, leftover = self._try_get_embedding(embedding_name)
                     if embed is None:
-                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
+                        logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
                     else:
                         if len(embed.shape) == 1:
                             tokens.append([(embed, weight)])

View File

@@ -9,7 +9,7 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 from contextlib import contextmanager
+import logging
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -19,14 +19,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
     else:
         if safe_load:
             if not 'weights_only' in torch.load.__code__.co_varnames:
-                print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
+                logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                 safe_load = False
         if safe_load:
             pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
             pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
         if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
+            logging.info(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
             sd = pl_sd["state_dict"]
         else:

View File

@@ -230,6 +230,23 @@ class SamplerDPMPP_SDE:
         sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
         return (sampler, )
+class SamplerEulerAncestral:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                     "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
+                    }
+                }
+
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, eta, s_noise):
+        sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
+        return (sampler, )
 class SamplerCustom:
     @classmethod
     def INPUT_TYPES(s):
@@ -290,6 +307,7 @@ NODE_CLASS_MAPPINGS = {
     "VPScheduler": VPScheduler,
     "SDTurboScheduler": SDTurboScheduler,
     "KSamplerSelect": KSamplerSelect,
+    "SamplerEulerAncestral": SamplerEulerAncestral,
     "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SplitSigmas": SplitSigmas,

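The new SamplerEulerAncestral node is a thin wrapper that returns a SAMPLER object for use with SamplerCustom, exposing eta and s_noise explicitly instead of selecting the sampler by name through KSamplerSelect. A hedged sketch of the call it performs (the absolute import below is an assumption based on the comfy.samplers reference in the added code; the surrounding context line uses the module as samplers.ksampler, so this fork may import it relatively):

from comfy import samplers  # assumed import path

# Same call as the node's get_sampler(): build a sampler object configured for
# euler_ancestral with explicit eta / s_noise values.
sampler = samplers.ksampler("euler_ancestral", {"eta": 1.0, "s_noise": 1.0})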
View File

@@ -86,6 +86,50 @@ class CLIPMergeSimple:
             m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )
+class CLIPSubtract:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip1": ("CLIP",),
+                              "clip2": ("CLIP",),
+                              "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "merge"
+
+    CATEGORY = "advanced/model_merging"
+
+    def merge(self, clip1, clip2, multiplier):
+        m = clip1.clone()
+        kp = clip2.get_key_patches()
+        for k in kp:
+            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
+                continue
+            m.add_patches({k: kp[k]}, - multiplier, multiplier)
+        return (m, )
+
+class CLIPAdd:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip1": ("CLIP",),
+                              "clip2": ("CLIP",),
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "merge"
+
+    CATEGORY = "advanced/model_merging"
+
+    def merge(self, clip1, clip2):
+        m = clip1.clone()
+        kp = clip2.get_key_patches()
+        for k in kp:
+            if k.endswith(".position_ids") or k.endswith(".logit_scale"):
+                continue
+            m.add_patches({k: kp[k]}, 1.0, 1.0)
+        return (m, )
 class ModelMergeBlocks:
     @classmethod
     def INPUT_TYPES(s):
@@ -278,6 +322,8 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeAdd": ModelAdd,
     "CheckpointSave": CheckpointSave,
     "CLIPMergeSimple": CLIPMergeSimple,
+    "CLIPMergeSubtract": CLIPSubtract,
+    "CLIPMergeAdd": CLIPAdd,
     "CLIPSave": CLIPSave,
     "VAESave": VAESave,
 }
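For readers comparing the three CLIP merge nodes in this file: assuming add_patches(patches, strength_patch, strength_model) resolves each patched key to strength_model * w_clip1 + strength_patch * w_clip2 (the convention CLIPMergeSimple already relies on above), the per-tensor arithmetic of the new nodes works out as in this hypothetical sketch (function names are illustrative, not project code):

def clip_merge_simple(w_clip1, w_clip2, ratio):
    # add_patches({k: kp[k]}, 1.0 - ratio, ratio)
    return ratio * w_clip1 + (1.0 - ratio) * w_clip2

def clip_subtract(w_clip1, w_clip2, multiplier):
    # add_patches({k: kp[k]}, -multiplier, multiplier)
    return multiplier * (w_clip1 - w_clip2)

def clip_add(w_clip1, w_clip2):
    # add_patches({k: kp[k]}, 1.0, 1.0)
    return w_clip1 + w_clip2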