mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Merge upstream
This commit is contained in:
commit
00728eb20f
@ -39,7 +39,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
|
|||||||
# patch nodes
|
# patch nodes
|
||||||
from ..nodes.base_nodes import SaveImage, CLIPTextEncode, LoraLoader, CheckpointLoaderSimple
|
from ..nodes.base_nodes import SaveImage, CLIPTextEncode, LoraLoader, CheckpointLoaderSimple
|
||||||
from ..cmd.execution import PromptQueue
|
from ..cmd.execution import PromptQueue
|
||||||
from comfy.component_model.queue_types import QueueItem
|
from ..component_model.queue_types import QueueItem
|
||||||
|
|
||||||
prompt_queue_put = PromptQueue.put
|
prompt_queue_put = PromptQueue.put
|
||||||
|
|
||||||
|
|||||||
@ -176,6 +176,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
|
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
|
||||||
parser.add_argument("--external-address", required=False,
|
parser.add_argument("--external-address", required=False,
|
||||||
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
|
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
|
||||||
|
|
||||||
# now give plugins a chance to add configuration
|
# now give plugins a chance to add configuration
|
||||||
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
||||||
@ -208,6 +209,12 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
|
|||||||
if args.disable_auto_launch:
|
if args.disable_auto_launch:
|
||||||
args.auto_launch = False
|
args.auto_launch = False
|
||||||
|
|
||||||
|
logging_level = logging.WARNING
|
||||||
|
if args.verbose:
|
||||||
|
logging_level = logging.DEBUG
|
||||||
|
|
||||||
|
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||||
|
|
||||||
return Configuration(**vars(args))
|
return Configuration(**vars(args))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import TypeAlias, Union
|
from typing import TypeAlias, Union
|
||||||
|
|
||||||
from comfy.api.components.schema.prompt import PromptDict, Prompt
|
from ..api.components.schema.prompt import PromptDict, Prompt
|
||||||
|
|
||||||
JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
|
JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
|
||||||
_BASE_PROMPT: JSON = {
|
_BASE_PROMPT: JSON = {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from .ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
class CLIPAttention(torch.nn.Module):
|
class CLIPAttention(torch.nn.Module):
|
||||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from . import ops
|
from . import ops
|
||||||
from . import model_patcher
|
from . import model_patcher
|
||||||
@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
m, u = clip.load_sd(sd)
|
m, u = clip.load_sd(sd)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("missing clip vision:", m)
|
logging.warning("missing clip vision: {}".format(m))
|
||||||
u = set(u)
|
u = set(u)
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
|
|||||||
@ -53,8 +53,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extr
|
|||||||
if h[x] == "PROMPT":
|
if h[x] == "PROMPT":
|
||||||
input_data_all[x] = [prompt]
|
input_data_all[x] = [prompt]
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
if "extra_pnginfo" in extra_data:
|
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from __future__ import annotations # for Python 3.7-3.9
|
|||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
from typing import Optional, Literal, Protocol, TypeAlias, Union
|
from typing import Optional, Literal, Protocol, TypeAlias, Union
|
||||||
|
|
||||||
from comfy.component_model.queue_types import BinaryEventTypes
|
from .queue_types import BinaryEventTypes
|
||||||
|
|
||||||
|
|
||||||
class ExecInfo(TypedDict):
|
class ExecInfo(TypedDict):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import model_management
|
from . import model_management
|
||||||
@ -368,7 +369,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
|
|
||||||
leftover_keys = controlnet_data.keys()
|
leftover_keys = controlnet_data.keys()
|
||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
print("leftover keys:", leftover_keys)
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
@ -383,7 +384,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
else:
|
else:
|
||||||
net = load_t2i_adapter(controlnet_data)
|
net = load_t2i_adapter(controlnet_data)
|
||||||
if net is None:
|
if net is None:
|
||||||
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
|
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
@ -418,7 +419,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
cd = controlnet_data[x]
|
cd = controlnet_data[x]
|
||||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||||
else:
|
else:
|
||||||
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
@ -427,7 +428,12 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||||
else:
|
else:
|
||||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||||
print(missing, unexpected)
|
|
||||||
|
if len(missing) > 0:
|
||||||
|
logging.warning("missing controlnet keys: {}".format(missing))
|
||||||
|
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
logging.info("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
global_average_pooling = False
|
global_average_pooling = False
|
||||||
filename = os.path.splitext(ckpt_path)[0]
|
filename = os.path.splitext(ckpt_path)[0]
|
||||||
@ -537,9 +543,9 @@ def load_t2i_adapter(t2i_data):
|
|||||||
|
|
||||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||||
if len(missing) > 0:
|
if len(missing) > 0:
|
||||||
print("t2i missing", missing)
|
logging.warning("t2i missing {}".format(missing))
|
||||||
|
|
||||||
if len(unexpected) > 0:
|
if len(unexpected) > 0:
|
||||||
print("t2i unexpected", unexpected)
|
logging.info("t2i unexpected {}".format(unexpected))
|
||||||
|
|
||||||
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for k, v in new_state_dict.items():
|
for k, v in new_state_dict.items():
|
||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
print(f"Reshaping {k} for SD format")
|
logging.info(f"Reshaping {k} for SD format")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import Optional
|
|||||||
from aio_pika import connect_robust
|
from aio_pika import connect_robust
|
||||||
from aio_pika.patterns import RPC
|
from aio_pika.patterns import RPC
|
||||||
|
|
||||||
from comfy.distributed.distributed_types import RpcRequest, RpcReply
|
from .distributed_types import RpcRequest, RpcReply
|
||||||
|
|
||||||
|
|
||||||
class DistributedPromptClient:
|
class DistributedPromptClient:
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from typing import Optional, OrderedDict, List, Dict
|
|||||||
import collections
|
import collections
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from comfy.component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
|
from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
|
||||||
|
|
||||||
|
|
||||||
class History:
|
class History:
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from .ldm.modules.attention import CrossAttention
|
from .ldm.modules.attention import CrossAttention
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from comfy.ops import manual_cast
|
from .ops import manual_cast
|
||||||
ops = manual_cast
|
ops = manual_cast
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
LORA_CLIP_MAP = {
|
LORA_CLIP_MAP = {
|
||||||
@ -156,7 +157,7 @@ def load_lora(lora, to_load):
|
|||||||
|
|
||||||
for x in lora.keys():
|
for x in lora.keys():
|
||||||
if x not in loaded_keys:
|
if x not in loaded_keys:
|
||||||
print("lora key not loaded", x)
|
logging.warning("lora key not loaded: {}".format(x))
|
||||||
return patch_dict
|
return patch_dict
|
||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
@ -66,8 +67,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
self.inpaint_model = False
|
self.inpaint_model = False
|
||||||
print("model_type", model_type.name)
|
logging.warning("model_type {}".format(model_type.name))
|
||||||
print("adm", self.adm_channels)
|
logging.info("adm {}".format(self.adm_channels))
|
||||||
|
|
||||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
sigma = t
|
sigma = t
|
||||||
@ -168,7 +169,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
c_concat = kwargs.get("noise_concat", None)
|
c_concat = kwargs.get("noise_concat", None)
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
|
out['c_concat'] = conds.CONDNoiseShape(data)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -182,10 +183,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("unet missing:", m)
|
logging.warning("unet missing: {}".format(m))
|
||||||
|
|
||||||
if len(u) > 0:
|
if len(u) > 0:
|
||||||
print("unet unexpected:", u)
|
logging.warning("unet unexpected: {}".format(u))
|
||||||
del to_load
|
del to_load
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from . import supported_models
|
from . import supported_models
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
|
import logging
|
||||||
|
|
||||||
def count_blocks(state_dict_keys, prefix_string):
|
def count_blocks(state_dict_keys, prefix_string):
|
||||||
count = 0
|
count = 0
|
||||||
@ -186,7 +187,7 @@ def model_config_from_unet_config(unet_config):
|
|||||||
if model_config.matches(unet_config):
|
if model_config.matches(unet_config):
|
||||||
return model_config(unet_config)
|
return model_config(unet_config)
|
||||||
|
|
||||||
print("no match", unet_config)
|
logging.error("no match {}".format(unet_config))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import psutil
|
import psutil
|
||||||
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
from . import utils
|
from . import utils
|
||||||
@ -33,7 +34,7 @@ lowvram_available = True
|
|||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
|
||||||
if args.deterministic:
|
if args.deterministic:
|
||||||
print("Using deterministic algorithms for pytorch")
|
logging.warning("Using deterministic algorithms for pytorch")
|
||||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
|
|
||||||
directml_enabled = False
|
directml_enabled = False
|
||||||
@ -45,7 +46,7 @@ if args.directml is not None:
|
|||||||
directml_device = torch_directml.device()
|
directml_device = torch_directml.device()
|
||||||
else:
|
else:
|
||||||
directml_device = torch_directml.device(device_index)
|
directml_device = torch_directml.device(device_index)
|
||||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||||
# torch_directml.disable_tiled_resources(True)
|
# torch_directml.disable_tiled_resources(True)
|
||||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||||
|
|
||||||
@ -121,10 +122,10 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
|
|
||||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
if not args.normalvram and not args.cpu:
|
if not args.normalvram and not args.cpu:
|
||||||
if lowvram_available and total_vram <= 4096:
|
if lowvram_available and total_vram <= 4096:
|
||||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -147,12 +148,10 @@ else:
|
|||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
XFORMERS_VERSION = xformers.version.__version__
|
XFORMERS_VERSION = xformers.version.__version__
|
||||||
print("xformers version:", XFORMERS_VERSION)
|
logging.warning("xformers version: {}".format(XFORMERS_VERSION))
|
||||||
if XFORMERS_VERSION.startswith("0.0.18"):
|
if XFORMERS_VERSION.startswith("0.0.18"):
|
||||||
print()
|
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||||
print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||||
print("Please downgrade or upgrade xformers to a different version.")
|
|
||||||
print()
|
|
||||||
XFORMERS_ENABLED_VAE = False
|
XFORMERS_ENABLED_VAE = False
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@ -217,11 +216,11 @@ elif args.highvram or args.gpu_only:
|
|||||||
FORCE_FP32 = False
|
FORCE_FP32 = False
|
||||||
FORCE_FP16 = False
|
FORCE_FP16 = False
|
||||||
if args.force_fp32:
|
if args.force_fp32:
|
||||||
print("Forcing FP32, if this improves things please report it.")
|
logging.warning("Forcing FP32, if this improves things please report it.")
|
||||||
FORCE_FP32 = True
|
FORCE_FP32 = True
|
||||||
|
|
||||||
if args.force_fp16 or cpu_state == CPUState.MPS:
|
if args.force_fp16 or cpu_state == CPUState.MPS:
|
||||||
print("Forcing FP16.")
|
logging.warning("Forcing FP16.")
|
||||||
FORCE_FP16 = True
|
FORCE_FP16 = True
|
||||||
|
|
||||||
if lowvram_available:
|
if lowvram_available:
|
||||||
@ -235,12 +234,12 @@ if cpu_state != CPUState.GPU:
|
|||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
vram_state = VRAMState.SHARED
|
vram_state = VRAMState.SHARED
|
||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
logging.warning(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||||
|
|
||||||
if DISABLE_SMART_MEMORY:
|
if DISABLE_SMART_MEMORY:
|
||||||
print("Disabling smart memory management")
|
logging.warning("Disabling smart memory management")
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
@ -258,11 +257,11 @@ def get_torch_device_name(device):
|
|||||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Device:", get_torch_device_name(get_torch_device()))
|
logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||||
except:
|
except:
|
||||||
print("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
|
||||||
print("VAE dtype:", VAE_DTYPE)
|
logging.warning("VAE dtype: {}".format(VAE_DTYPE))
|
||||||
|
|
||||||
current_loaded_models = []
|
current_loaded_models = []
|
||||||
|
|
||||||
@ -305,7 +304,7 @@ class LoadedModel:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
if lowvram_model_memory > 0:
|
if lowvram_model_memory > 0:
|
||||||
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
|
logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
for m in self.real_model.modules():
|
for m in self.real_model.modules():
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
@ -318,7 +317,7 @@ class LoadedModel:
|
|||||||
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
||||||
m.to(self.device)
|
m.to(self.device)
|
||||||
mem_counter += module_size(m)
|
mem_counter += module_size(m)
|
||||||
print("lowvram: loaded module regularly", m)
|
logging.warning("lowvram: loaded module regularly {}".format(m))
|
||||||
|
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
@ -353,7 +352,7 @@ def unload_model_clones(model):
|
|||||||
to_unload = [i] + to_unload
|
to_unload = [i] + to_unload
|
||||||
|
|
||||||
for i in to_unload:
|
for i in to_unload:
|
||||||
print("unload clone", i)
|
logging.warning("unload clone {}".format(i))
|
||||||
current_loaded_models.pop(i).model_unload()
|
current_loaded_models.pop(i).model_unload()
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[]):
|
def free_memory(memory_required, device, keep_loaded=[]):
|
||||||
@ -397,7 +396,7 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
models_already_loaded.append(loaded_model)
|
models_already_loaded.append(loaded_model)
|
||||||
else:
|
else:
|
||||||
if hasattr(x, "model"):
|
if hasattr(x, "model"):
|
||||||
print(f"Requested to load {x.model.__class__.__name__}")
|
logging.warning(f"Requested to load {x.model.__class__.__name__}")
|
||||||
models_to_load.append(loaded_model)
|
models_to_load.append(loaded_model)
|
||||||
|
|
||||||
if len(models_to_load) == 0:
|
if len(models_to_load) == 0:
|
||||||
@ -407,7 +406,7 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
free_memory(extra_mem, d, models_already_loaded)
|
free_memory(extra_mem, d, models_already_loaded)
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import model_management
|
from . import model_management
|
||||||
@ -187,7 +188,7 @@ class ModelPatcher:
|
|||||||
model_sd = self.model_state_dict()
|
model_sd = self.model_state_dict()
|
||||||
for key in self.patches:
|
for key in self.patches:
|
||||||
if key not in model_sd:
|
if key not in model_sd:
|
||||||
print("could not patch. key doesn't exist in model:", key)
|
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight = model_sd[key]
|
weight = model_sd[key]
|
||||||
@ -236,7 +237,7 @@ class ModelPatcher:
|
|||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
if alpha != 0.0:
|
if alpha != 0.0:
|
||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
|
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif patch_type == "lora": #lora/locon
|
elif patch_type == "lora": #lora/locon
|
||||||
@ -252,7 +253,7 @@ class ModelPatcher:
|
|||||||
try:
|
try:
|
||||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("ERROR", key, e)
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "lokr":
|
elif patch_type == "lokr":
|
||||||
w1 = v[0]
|
w1 = v[0]
|
||||||
w2 = v[1]
|
w2 = v[1]
|
||||||
@ -291,7 +292,7 @@ class ModelPatcher:
|
|||||||
try:
|
try:
|
||||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("ERROR", key, e)
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "loha":
|
elif patch_type == "loha":
|
||||||
w1a = v[0]
|
w1a = v[0]
|
||||||
w1b = v[1]
|
w1b = v[1]
|
||||||
@ -320,7 +321,7 @@ class ModelPatcher:
|
|||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("ERROR", key, e)
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
if v[4] is not None:
|
||||||
alpha *= v[4] / v[0].shape[0]
|
alpha *= v[4] / v[0].shape[0]
|
||||||
@ -330,9 +331,12 @@ class ModelPatcher:
|
|||||||
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||||
|
|
||||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
try:
|
||||||
|
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
else:
|
else:
|
||||||
print("patch type not recognized", patch_type, key)
|
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
from .ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||||
import math
|
import math
|
||||||
|
|
||||||
class EPS:
|
class EPS:
|
||||||
|
|||||||
33
comfy/sd.py
33
comfy/sd.py
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
|
||||||
from . import model_management
|
from . import model_management
|
||||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||||
@ -37,7 +38,7 @@ def load_model_weights(model, sd):
|
|||||||
w = sd.pop(x)
|
w = sd.pop(x)
|
||||||
del w
|
del w
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("missing", m)
|
logging.warning("missing {}".format(m))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_clip_weights(model, sd):
|
def load_clip_weights(model, sd):
|
||||||
@ -81,7 +82,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
|||||||
k1 = set(k1)
|
k1 = set(k1)
|
||||||
for x in loaded:
|
for x in loaded:
|
||||||
if (x not in k) and (x not in k1):
|
if (x not in k) and (x not in k1):
|
||||||
print("NOT LOADED", x)
|
logging.warning("NOT LOADED {}".format(x))
|
||||||
|
|
||||||
return (new_modelpatcher, new_clip)
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
@ -225,10 +226,10 @@ class VAE:
|
|||||||
|
|
||||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("Missing VAE keys", m)
|
logging.warning("Missing VAE keys {}".format(m))
|
||||||
|
|
||||||
if len(u) > 0:
|
if len(u) > 0:
|
||||||
print("Leftover VAE keys", u)
|
logging.info("Leftover VAE keys {}".format(u))
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.vae_device()
|
device = model_management.vae_device()
|
||||||
@ -291,7 +292,7 @@ class VAE:
|
|||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||||
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
@ -317,7 +318,7 @@ class VAE:
|
|||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
@ -393,10 +394,10 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("clip missing:", m)
|
logging.warning("clip missing: {}".format(m))
|
||||||
|
|
||||||
if len(u) > 0:
|
if len(u) > 0:
|
||||||
print("clip unexpected:", u)
|
logging.info("clip unexpected: {}".format(u))
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load_gligen(ckpt_path):
|
def load_gligen(ckpt_path):
|
||||||
@ -534,21 +535,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("clip missing:", m)
|
logging.warning("clip missing: {}".format(m))
|
||||||
|
|
||||||
if len(u) > 0:
|
if len(u) > 0:
|
||||||
print("clip unexpected:", u)
|
logging.info("clip unexpected {}:".format(u))
|
||||||
else:
|
else:
|
||||||
print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||||
|
|
||||||
left_over = sd.keys()
|
left_over = sd.keys()
|
||||||
if len(left_over) > 0:
|
if len(left_over) > 0:
|
||||||
print("left over keys:", left_over)
|
logging.info("left over keys: {}".format(left_over))
|
||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
print("loaded straight to GPU")
|
logging.warning("loaded straight to GPU")
|
||||||
model_management.load_model_gpu(_model_patcher)
|
model_management.load_model_gpu(_model_patcher)
|
||||||
|
|
||||||
return (_model_patcher, clip, vae, clipvision)
|
return (_model_patcher, clip, vae, clipvision)
|
||||||
@ -577,7 +578,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||||||
if k in sd:
|
if k in sd:
|
||||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||||
else:
|
else:
|
||||||
print(diffusers_keys[k], k)
|
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||||
@ -588,14 +589,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
left_over = sd.keys()
|
left_over = sd.keys()
|
||||||
if len(left_over) > 0:
|
if len(left_over) > 0:
|
||||||
print("left over keys in unet:", left_over)
|
logging.warning("left over keys in unet: {}".format(left_over))
|
||||||
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_unet(unet_path):
|
def load_unet(unet_path):
|
||||||
sd = utils.load_torch_file(unet_path)
|
sd = utils.load_torch_file(unet_path)
|
||||||
model = load_unet_state_dict(sd)
|
model = load_unet_state_dict(sd)
|
||||||
if model is None:
|
if model is None:
|
||||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from . import model_management
|
|||||||
from pkg_resources import resource_filename
|
from pkg_resources import resource_filename
|
||||||
from . import clip_model
|
from . import clip_model
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
start_token = special_tokens.get("start", None)
|
start_token = special_tokens.get("start", None)
|
||||||
@ -140,7 +141,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens_temp += [next_new_token]
|
tokens_temp += [next_new_token]
|
||||||
next_new_token += 1
|
next_new_token += 1
|
||||||
else:
|
else:
|
||||||
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
|
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
|
||||||
while len(tokens_temp) < len(x):
|
while len(tokens_temp) < len(x):
|
||||||
tokens_temp += [self.special_tokens["pad"]]
|
tokens_temp += [self.special_tokens["pad"]]
|
||||||
out_tokens += [tokens_temp]
|
out_tokens += [tokens_temp]
|
||||||
@ -332,9 +333,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
|||||||
else:
|
else:
|
||||||
embed = torch.load(embed_path, map_location="cpu")
|
embed = torch.load(embed_path, map_location="cpu")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||||
print()
|
|
||||||
print("error loading embedding, skipping loading:", embedding_name)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if embed_out is None:
|
if embed_out is None:
|
||||||
@ -429,7 +428,7 @@ class SDTokenizer:
|
|||||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||||
embed, leftover = self._try_get_embedding(embedding_name)
|
embed, leftover = self._try_get_embedding(embedding_name)
|
||||||
if embed is None:
|
if embed is None:
|
||||||
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||||
else:
|
else:
|
||||||
if len(embed.shape) == 1:
|
if len(embed.shape) == 1:
|
||||||
tokens.append([(embed, weight)])
|
tokens.append([(embed, weight)])
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -19,14 +19,14 @@ def load_torch_file(ckpt, safe_load=False, device=None):
|
|||||||
else:
|
else:
|
||||||
if safe_load:
|
if safe_load:
|
||||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||||
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||||
safe_load = False
|
safe_load = False
|
||||||
if safe_load:
|
if safe_load:
|
||||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
|
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
logging.info(f"Global Step: {pl_sd['global_step']}")
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
sd = pl_sd["state_dict"]
|
sd = pl_sd["state_dict"]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -230,6 +230,23 @@ class SamplerDPMPP_SDE:
|
|||||||
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
class SamplerEulerAncestral:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||||
|
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SAMPLER",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/samplers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
|
def get_sampler(self, eta, s_noise):
|
||||||
|
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
|
||||||
|
return (sampler, )
|
||||||
|
|
||||||
class SamplerCustom:
|
class SamplerCustom:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -290,6 +307,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"VPScheduler": VPScheduler,
|
"VPScheduler": VPScheduler,
|
||||||
"SDTurboScheduler": SDTurboScheduler,
|
"SDTurboScheduler": SDTurboScheduler,
|
||||||
"KSamplerSelect": KSamplerSelect,
|
"KSamplerSelect": KSamplerSelect,
|
||||||
|
"SamplerEulerAncestral": SamplerEulerAncestral,
|
||||||
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
||||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||||
"SplitSigmas": SplitSigmas,
|
"SplitSigmas": SplitSigmas,
|
||||||
|
|||||||
@ -86,6 +86,50 @@ class CLIPMergeSimple:
|
|||||||
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPSubtract:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "clip1": ("CLIP",),
|
||||||
|
"clip2": ("CLIP",),
|
||||||
|
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
FUNCTION = "merge"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
|
def merge(self, clip1, clip2, multiplier):
|
||||||
|
m = clip1.clone()
|
||||||
|
kp = clip2.get_key_patches()
|
||||||
|
for k in kp:
|
||||||
|
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
||||||
|
continue
|
||||||
|
m.add_patches({k: kp[k]}, - multiplier, multiplier)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPAdd:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "clip1": ("CLIP",),
|
||||||
|
"clip2": ("CLIP",),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CLIP",)
|
||||||
|
FUNCTION = "merge"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
|
def merge(self, clip1, clip2):
|
||||||
|
m = clip1.clone()
|
||||||
|
kp = clip2.get_key_patches()
|
||||||
|
for k in kp:
|
||||||
|
if k.endswith(".position_ids") or k.endswith(".logit_scale"):
|
||||||
|
continue
|
||||||
|
m.add_patches({k: kp[k]}, 1.0, 1.0)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
class ModelMergeBlocks:
|
class ModelMergeBlocks:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -278,6 +322,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ModelMergeAdd": ModelAdd,
|
"ModelMergeAdd": ModelAdd,
|
||||||
"CheckpointSave": CheckpointSave,
|
"CheckpointSave": CheckpointSave,
|
||||||
"CLIPMergeSimple": CLIPMergeSimple,
|
"CLIPMergeSimple": CLIPMergeSimple,
|
||||||
|
"CLIPMergeSubtract": CLIPSubtract,
|
||||||
|
"CLIPMergeAdd": CLIPAdd,
|
||||||
"CLIPSave": CLIPSave,
|
"CLIPSave": CLIPSave,
|
||||||
"VAESave": VAESave,
|
"VAESave": VAESave,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user