Merge remote-tracking branch 'upstream/master' into addBatchIndex

commit 3ce4dc988d
@@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
    else:
        raise AssertionError('Unknown merge analysis result')

pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
@@ -30,6 +30,7 @@ jobs:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          persist-credentials: false
      - shell: bash
        run: |
          cd ..
@@ -17,6 +17,7 @@ jobs:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          persist-credentials: false
      - uses: actions/setup-python@v4
        with:
          python-version: '3.11.3'
@@ -1,14 +1,5 @@
import json
import os
import yaml

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file

# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    if 'prediction_type' in diffusers_scheduler_conf:
        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = instantiate_from_config(config["model"])
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if fp16:
        model = model.half()

    return ModelPatcher(model), clip, vae
111 comfy/diffusers_load.py Normal file
@@ -0,0 +1,111 @@
import json
import os
import yaml

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
import diffusers_convert

def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    if 'prediction_type' in diffusers_scheduler_conf:
        v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor
    model_config_params["unet_config"]["params"]["use_fp16"] = fp16

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = instantiate_from_config(config["model"])
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if fp16:
        model = model.half()

    return ModelPatcher(model), clip, vae
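Editor's note, not part of the commit: a minimal sketch of how the new loader might be driven from Python, assuming ComfyUI is importable as the comfy package and using a hypothetical diffusers-format checkpoint directory.

    # Hypothetical usage sketch; the path and import route are assumptions.
    from comfy.diffusers_load import load_diffusers

    model_patcher, clip, vae = load_diffusers(
        "/path/to/stable-diffusion-v1-5",  # made-up local diffusers checkpoint with unet/, vae/, text_encoder/, scheduler/
        fp16=True,
        output_vae=True,
        output_clip=True,
    )
    # model_patcher wraps the converted UNet; clip and vae plug into the usual ComfyUI nodes.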
@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
|
||||
total_vram = 0
|
||||
total_vram_available_mb = -1
|
||||
|
||||
accelerate_enabled = False
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
@ -31,11 +30,12 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import torch
|
||||
if directml_enabled:
|
||||
total_vram = 4097 #TODO
|
||||
pass #TODO
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
@ -46,7 +46,7 @@ try:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
if not args.normalvram and not args.cpu:
|
||||
if total_vram <= 4096:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
if args.lowvram:
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
lowvram_available = True
|
||||
elif args.novram:
|
||||
set_vram_to = VRAMState.NO_VRAM
|
||||
elif args.highvram:
|
||||
@ -103,18 +104,18 @@ if args.force_fp32:
|
||||
FORCE_FP32 = True
|
||||
|
||||
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
accelerate_enabled = True
|
||||
vram_state = set_vram_to
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
||||
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
|
||||
lowvram_available = False
|
||||
|
||||
total_vram_available_mb = (total_vram - 1024) // 2
|
||||
total_vram_available_mb = int(max(256, total_vram_available_mb))
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@ -199,22 +200,33 @@ def load_model_gpu(model):
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
model.model_patches_to(get_torch_device())
|
||||
torch_dev = get_torch_device()
|
||||
model.model_patches_to(torch_dev)
|
||||
|
||||
vram_set_state = vram_state
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = model.model_size()
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
current_loaded_model = model
|
||||
if vram_state == VRAMState.CPU:
|
||||
|
||||
if vram_set_state == VRAMState.CPU:
|
||||
pass
|
||||
elif vram_state == VRAMState.MPS:
|
||||
elif vram_set_state == VRAMState.MPS:
|
||||
mps_device = torch.device("mps")
|
||||
real_model.to(mps_device)
|
||||
pass
|
||||
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
|
||||
model_accelerated = False
|
||||
real_model.to(get_torch_device())
|
||||
else:
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||
model_accelerated = True
|
||||
@ -377,7 +389,10 @@ def should_use_fp16():
|
||||
|
||||
def soft_empty_cache():
|
||||
global xpu_available
|
||||
if xpu_available:
|
||||
global vram_state
|
||||
if vram_state == VRAMState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
|
||||
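Editor's note, not part of the commit: the low-VRAM branch above decides how much of the model accelerate may keep on the GPU. A standalone sketch of that budgeting rule, with made-up numbers:

    # Sketch of the VRAM budget logic used in load_model_gpu above.
    def lowvram_budget(current_free_mem_bytes, model_size_bytes):
        # keep at least 256 MiB, otherwise use roughly (free - 1 GiB) / 1.3 as the on-GPU budget
        budget = int(max(256 * (1024 * 1024), (current_free_mem_bytes - 1024 * (1024 * 1024)) / 1.3))
        # only fall back to low-VRAM mode if the model would not comfortably fit
        needs_lowvram = model_size_bytes > (current_free_mem_bytes - 512 * (1024 * 1024))
        return budget, needs_lowvram

    # Example: 6 GiB free and a 5 GiB model -> low-VRAM mode with a ~3.8 GiB budget
    print(lowvram_budget(6 * 1024**3, 5 * 1024**3))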
35 comfy/sd.py
@@ -14,6 +14,7 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert

def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
    m, u = model.load_state_dict(sd, strict=False)
@@ -285,15 +286,29 @@ def model_lora_keys(model, key_map={}):

    return key_map


class ModelPatcher:
    def __init__(self, model):
    def __init__(self, model, size=0):
        self.size = size
        self.model = model
        self.patches = []
        self.backup = {}
        self.model_options = {"transformer_options":{}}
        self.model_size()

    def model_size(self):
        if self.size > 0:
            return self.size
        model_sd = self.model.state_dict()
        size = 0
        for k in model_sd:
            t = model_sd[k]
            size += t.nelement() * t.element_size()
        self.size = size
        return size

    def clone(self):
        n = ModelPatcher(self.model)
        n = ModelPatcher(self.model, self.size)
        n.patches = self.patches[:]
        n.model_options = copy.deepcopy(self.model_options)
        return n
@@ -504,10 +519,16 @@ class VAE:
        if config is None:
            #default SD1.x/SD2.x VAE parameters
            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
        else:
            self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
            self.first_stage_model = AutoencoderKL(**(config['params']))
        self.first_stage_model = self.first_stage_model.eval()
        if ckpt_path is not None:
            sd = utils.load_torch_file(ckpt_path)
            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
                sd = diffusers_convert.convert_vae_state_dict(sd)
            self.first_stage_model.load_state_dict(sd, strict=False)

        self.scale_factor = scale_factor
        if device is None:
            device = model_management.get_torch_device()
@@ -722,7 +743,7 @@ def load_controlnet(ckpt_path, model=None):
                            use_spatial_transformer=True,
                            transformer_depth=1,
                            context_dim=context_dim,
                            use_checkpoint=True,
                            use_checkpoint=False,
                            legacy=False,
                            use_fp16=use_fp16)
    else:
@@ -739,7 +760,7 @@ def load_controlnet(ckpt_path, model=None):
                            use_linear_in_transformer=True,
                            transformer_depth=1,
                            context_dim=context_dim,
                            use_checkpoint=True,
                            use_checkpoint=False,
                            legacy=False,
                            use_fp16=use_fp16)
    if pth:
@@ -1024,7 +1045,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    }

    unet_config = {
        "use_checkpoint": True,
        "use_checkpoint": False,
        "image_size": 32,
        "out_channels": 4,
        "attention_resolutions": [
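Editor's note, not part of the commit: the new ModelPatcher.model_size() simply sums parameter bytes from the state dict. The same byte-counting idea can be verified on any small module (the toy layer below is only an illustration):

    import torch

    # sum nelement() * element_size() over every tensor in the state dict
    model = torch.nn.Linear(1024, 1024)  # illustrative module, not a diffusion model
    size = sum(t.nelement() * t.element_size() for t in model.state_dict().values())
    print(size)  # 1024*1024 weights + 1024 bias, 4 bytes each -> 4198400 bytes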
@@ -1,11 +1,16 @@
import torch
import math
import struct

def load_torch_file(ckpt, safe_load=False):
    if ckpt.lower().endswith(".safetensors"):
        import safetensors.torch
        sd = safetensors.torch.load_file(ckpt, device="cpu")
    else:
        if safe_load:
            if not 'weights_only' in torch.load.__code__.co_varnames:
                print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
        else:
@@ -46,6 +51,88 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
            sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
    return sd

def safetensors_header(safetensors_path, max_size=100*1024*1024):
    with open(safetensors_path, "rb") as f:
        header = f.read(8)
        length_of_header = struct.unpack('<Q', header)[0]
        if length_of_header > max_size:
            return None
        return f.read(length_of_header)

def bislerp(samples, width, height):
    def slerp(b1, b2, r):
        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''

        c = b1.shape[-1]

        #norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        #normalize
        b1_normalized = b1 / b1_norms
        b2_normalized = b2 / b2_norms

        #zero when norms are zero
        b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
        b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0

        #slerp
        dot = (b1_normalized*b2_normalized).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)

        #technically not mathematically correct, but more pleasing?
        res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
        res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)

        #edge cases for same or polar opposites
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
        return res

    def generate_bilinear_data(length_old, length_new):
        coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
        coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
        coords_2[:,:,:,-1] -= 1
        coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    n,c,h,w = samples.shape
    h_new, w_new = (height, width)

    #linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    #linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
    coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
    pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
    ratios = ratios.movedim(1, -1).reshape((-1,1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result

def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
        old_width = samples.shape[3]
@@ -61,7 +148,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
        s = samples[:,:,y:old_height-y,x:old_width-x]
    else:
        s = samples
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

    if upscale_method == "bislerp":
        return bislerp(s, width, height)
    else:
        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
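Editor's note, not part of the commit: a quick way to exercise the new "bislerp" path of common_upscale on a random latent-shaped tensor, assuming the comfy package is importable; the shapes are arbitrary examples.

    import torch
    from comfy import utils  # assumes ComfyUI is on the Python path

    latent = torch.rand(1, 4, 64, 64)  # NCHW latent-like tensor
    up = utils.common_upscale(latent, 96, 96, "bislerp", crop="disabled")
    print(up.shape)  # torch.Size([1, 4, 96, 96])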
@@ -0,0 +1,110 @@
import math

import torch.nn as nn


class CA_layer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CA_layer, self).__init__()
        # global average pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
            nn.GELU(),
            nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
            # nn.Sigmoid()
        )

    def forward(self, x):
        y = self.fc(self.gap(x))
        return x * y.expand_as(x)


class Simple_CA_layer(nn.Module):
    def __init__(self, channel):
        super(Simple_CA_layer, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(
            in_channels=channel,
            out_channels=channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

    def forward(self, x):
        return x * self.fc(self.gap(x))


class ECA_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel):
        super(ECA_layer, self).__init__()

        b = 1
        gamma = 2
        k_size = int(abs(math.log(channel, 2) + b) / gamma)
        k_size = k_size if k_size % 2 else k_size + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        # y = self.sigmoid(y)

        return x * y.expand_as(x)


class ECA_MaxPool_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, channel):
        super(ECA_MaxPool_layer, self).__init__()

        b = 1
        gamma = 2
        k_size = int(abs(math.log(channel, 2) + b) / gamma)
        k_size = k_size if k_size % 2 else k_size + 1
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # b, c, h, w = x.size()

        # feature descriptor on the global spatial information
        y = self.max_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        # y = self.sigmoid(y)

        return x * y.expand_as(x)
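Editor's note, not part of the commit: a shape check for the ECA-style attention defined above, assuming ECA_layer from that file is in scope; the channel count and tensor size are arbitrary.

    import torch

    eca = ECA_layer(channel=64)          # kernel size is derived from the channel count
    x = torch.rand(1, 64, 32, 32)        # [b, c, h, w] feature map
    print(eca(x).shape)                  # torch.Size([1, 64, 32, 32]), output keeps the input shape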
201 comfy_extras/chainner_models/architecture/OmniSR/LICENSE Normal file
@@ -0,0 +1,201 @@
Apache License, Version 2.0, January 2004, http://www.apache.org/licenses/ (standard license text)
577 comfy_extras/chainner_models/architecture/OmniSR/OSA.py Normal file
@ -0,0 +1,577 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSA.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from torch import einsum, nn
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def cast_tuple(val, length=1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
|
||||
# helper classes
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class Conv_PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Gated_Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = int(dim * mult)
|
||||
|
||||
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
||||
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=hidden_features * 2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.project_in(x)
|
||||
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||
x = F.gelu(x1) * x2
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
|
||||
# MBConv
|
||||
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate=0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce("b c h w -> b c", "mean"),
|
||||
nn.Linear(dim, hidden_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias=False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange("b c -> b c 1 1"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob=0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0.0 or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = (
|
||||
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
||||
> self.prob
|
||||
)
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
|
||||
def MBConv(
|
||||
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(
|
||||
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
||||
),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
# nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout=dropout)
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# attention related classes
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
if self.with_pe:
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = rearrange(grid, "c i j -> (i j) c")
|
||||
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
||||
grid, "j ... -> 1 j ..."
|
||||
)
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = (
|
||||
*x.shape,
|
||||
x.device,
|
||||
self.heads,
|
||||
)
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# add positional bias
|
||||
if self.with_pe:
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, "i j h -> h i j")
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(
|
||||
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
||||
)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
||||
|
||||
|
||||
class Block_Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.ps = window_size
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
# project for queries, keys, values
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
||||
h=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# attention
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
out = rearrange(
|
||||
out,
|
||||
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
||||
x=h // self.ps,
|
||||
y=w // self.ps,
|
||||
head=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
)
|
||||
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention_grid(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention_grid, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class OSA_Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
ffn_bias=True,
|
||||
window_size=8,
|
||||
with_pe=False,
|
||||
dropout=0.0,
|
||||
):
|
||||
super(OSA_Block, self).__init__()
|
||||
|
||||
w = window_size
|
||||
|
||||
self.layer = nn.Sequential(
|
||||
MBConv(
|
||||
channel_num,
|
||||
channel_num,
|
||||
downsample=False,
|
||||
expansion_rate=1,
|
||||
shrinkage_rate=0.25,
|
||||
),
|
||||
Rearrange(
|
||||
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # block-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
Rearrange(
|
||||
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # grid-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention_grid(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return out
|
||||
60 comfy_extras/chainner_models/architecture/OmniSR/OSAG.py Normal file
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSAG.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################


import torch.nn as nn

from .esa import ESA
from .OSA import OSA_Block


class OSAG(nn.Module):
    def __init__(
        self,
        channel_num=64,
        bias=True,
        block_num=4,
        ffn_bias=False,
        window_size=0,
        pe=False,
    ):
        super(OSAG, self).__init__()

        # print("window_size: %d" % (window_size))
        # print("with_pe", pe)
        # print("ffn_bias: %d" % (ffn_bias))

        # block_script_name = kwargs.get("block_script_name", "OSA")
        # block_class_name = kwargs.get("block_class_name", "OSA_Block")

        # script_name = "." + block_script_name
        # package = __import__(script_name, fromlist=True)
        block_class = OSA_Block  # getattr(package, block_class_name)
        group_list = []
        for _ in range(block_num):
            temp_res = block_class(
                channel_num,
                bias,
                ffn_bias=ffn_bias,
                window_size=window_size,
                with_pe=pe,
            )
            group_list.append(temp_res)
        group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
        self.residual_layer = nn.Sequential(*group_list)
        esa_channel = max(channel_num // 4, 16)
        self.esa = ESA(esa_channel, channel_num)

    def forward(self, x):
        out = self.residual_layer(x)
        out = out + x
        return self.esa(out)
133 comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py Normal file
@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OmniSR.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .OSAG import OSAG
|
||||
from .pixelshuffle import pixelshuffle_block
|
||||
|
||||
|
||||
class OmniSR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
):
|
||||
super(OmniSR, self).__init__()
|
||||
self.state = state_dict
|
||||
|
||||
bias = True # Fine to assume this for now
|
||||
block_num = 1 # Fine to assume this for now
|
||||
ffn_bias = True
|
||||
pe = True
|
||||
|
||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||
|
||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||
if up_scale - int(up_scale) > 0:
|
||||
print(
|
||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||
)
|
||||
up_scale = int(up_scale)
|
||||
res_num = 0
|
||||
for key in state_dict.keys():
|
||||
if "residual_layer" in key:
|
||||
temp_res_num = int(key.split(".")[1])
|
||||
if temp_res_num > res_num:
|
||||
res_num = temp_res_num
|
||||
res_num = res_num + 1  # res_num was the highest zero-based index; +1 gives the group count
|
||||
|
||||
residual_layer = []
|
||||
self.res_num = res_num
|
||||
|
||||
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
|
||||
self.up_scale = up_scale
|
||||
|
||||
for _ in range(res_num):
|
||||
temp_res = OSAG(
|
||||
channel_num=num_feat,
|
||||
bias=bias,
|
||||
block_num=block_num,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=self.window_size,
|
||||
pe=pe,
|
||||
)
|
||||
residual_layer.append(temp_res)
|
||||
self.residual_layer = nn.Sequential(*residual_layer)
|
||||
self.input = nn.Conv2d(
|
||||
in_channels=num_in_ch,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.output = nn.Conv2d(
|
||||
in_channels=num_feat,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||
|
||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||
|
||||
# chaiNNer specific stuff
|
||||
self.model_arch = "OmniSR"
|
||||
self.sub_type = "SR"
|
||||
self.in_nc = num_in_ch
|
||||
self.out_nc = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.scale = up_scale
|
||||
|
||||
self.supports_fp16 = True # TODO: Test this
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = 16
|
||||
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
# import pdb; pdb.set_trace()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
residual = self.input(x)
|
||||
out = self.residual_layer(residual)
|
||||
|
||||
# origin
|
||||
out = torch.add(self.output(out), residual)
|
||||
out = self.up(out)
|
||||
|
||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||
return out
|
||||
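The loader above never stores the scale explicitly; it recovers it from the pixel-shuffle head, whose conv has num_out_ch * up_scale**2 output channels. A short sketch of that arithmetic with invented shapes:

import math

num_out_ch = 3                   # RGB output, as assumed by the loader
pixelshuffle_shape = 48          # hypothetical state_dict["up.0.weight"].shape[0]

up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)   # sqrt(16) = 4.0
if up_scale - int(up_scale) > 0:
    print("out_nc is probably different than in_nc, scale calculation might be wrong")
print(int(up_scale))             # 4 -> a 4x OmniSR checkpoint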
294  comfy_extras/chainner_models/architecture/OmniSR/esa.py  Normal file
@@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: esa.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
|
||||
def moment(x, dim=(2, 3), k=2):
|
||||
assert len(x.size()) == 4
|
||||
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
||||
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
||||
return mk
|
||||
|
||||
|
||||
class ESA(nn.Module):
|
||||
"""
|
||||
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
||||
`Residual Feature Aggregation Network for Image Super-Resolution`
|
||||
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
||||
are deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
||||
super(ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
||||
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
c1 = self.conv2(c1_)
|
||||
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
||||
c3 = self.conv3(v_max)
|
||||
c3 = F.interpolate(
|
||||
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
||||
)
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(c3 + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA_LN(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA_LN, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.norm = LayerNorm2d(n_feats)
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.norm(x)
|
||||
c1_ = self.conv1(c1_)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaGuidedFilter, self).__init__()
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=n_feats,
|
||||
out_channels=1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.r = 5
|
||||
|
||||
def box_filter(self, x, r):
|
||||
channel = x.shape[1]
|
||||
kernel_size = 2 * r + 1
|
||||
weight = 1.0 / (kernel_size**2)
|
||||
box_kernel = weight * torch.ones(
|
||||
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
||||
)
|
||||
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
||||
return output
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
N = self.box_filter(
|
||||
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
||||
)
|
||||
|
||||
# epsilon = self.fc(self.gap(x))
|
||||
# epsilon = torch.pow(epsilon, 2)
|
||||
epsilon = 1e-2
|
||||
|
||||
mean_x = self.box_filter(x, self.r) / N
|
||||
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
||||
|
||||
A = var_x / (var_x + epsilon)
|
||||
b = (1 - A) * mean_x
|
||||
m = A * x + b
|
||||
|
||||
# mean_A = self.box_filter(A, self.r) / N
|
||||
# mean_b = self.box_filter(b, self.r) / N
|
||||
# m = mean_A * x + mean_b
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaConvGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaConvGuidedFilter, self).__init__()
|
||||
f = esa_channels
|
||||
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.vec_conv(x)
|
||||
y = self.hor_conv(y)
|
||||
|
||||
sigma = torch.pow(y, 2)
|
||||
epsilon = self.fc(self.gap(y))
|
||||
|
||||
weight = sigma / (sigma + epsilon)
|
||||
|
||||
m = weight * x + (1 - weight)
|
||||
|
||||
return x * m
|
||||
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: layernorm.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias, eps):
|
||||
ctx.eps = eps
|
||||
N, C, H, W = x.size()
|
||||
mu = x.mean(1, keepdim=True)
|
||||
var = (x - mu).pow(2).mean(1, keepdim=True)
|
||||
y = (x - mu) / (var + eps).sqrt()
|
||||
ctx.save_for_backward(y, var, weight)
|
||||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
eps = ctx.eps
|
||||
|
||||
N, C, H, W = grad_output.size()
|
||||
y, var, weight = ctx.saved_variables
|
||||
g = grad_output * weight.view(1, C, 1, 1)
|
||||
mean_g = g.mean(dim=1, keepdim=True)
|
||||
|
||||
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
||||
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
||||
return (
|
||||
gx,
|
||||
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, channels, eps=1e-6):
|
||||
super(LayerNorm2d, self).__init__()
|
||||
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
||||
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
"""GRN (Global Response Normalization) layer"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
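LayerNorm2d above normalizes every spatial position across the channel dimension through a hand-written autograd function. A minimal check of the forward math in plain torch, using an identity affine so the effect is visible:

import torch

x = torch.randn(2, 64, 16, 16)
weight, bias, eps = torch.ones(64), torch.zeros(64), 1e-6

mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
y = weight.view(1, -1, 1, 1) * y + bias.view(1, -1, 1, 1)
print(y.mean(1).abs().max() < 1e-5)  # tensor(True): zero mean over channels at each position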
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: pixelshuffle.py
# Created Date: Friday July 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 1st July 2022 10:18:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import torch.nn as nn


def pixelshuffle_block(
    in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
    """
    Upsample features according to `upscale_factor`.
    """
    padding = kernel_size // 2
    conv = nn.Conv2d(
        in_channels,
        out_channels * (upscale_factor**2),
        kernel_size,
        padding=1,
        bias=bias,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return nn.Sequential(*[conv, pixel_shuffle])
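A quick shape check of what the block above builds, assuming a 2x upscale. Note the conv hard-codes padding=1, which only equals kernel_size // 2 for the default kernel_size=3.

import torch
import torch.nn as nn

r = 2
conv = nn.Conv2d(64, 3 * r**2, 3, padding=1)    # 64 feature channels -> out_channels * r^2
up = nn.Sequential(conv, nn.PixelShuffle(r))

x = torch.randn(1, 64, 32, 32)
print(up(x).shape)  # torch.Size([1, 3, 64, 64]) -- PixelShuffle trades channels for resolution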
@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
|
||||
self.scale: int = self.get_scale()
|
||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
||||
|
||||
c2x2 = False
|
||||
if self.state["model.0.weight"].shape[-2] == 2:
|
||||
c2x2 = True
|
||||
self.scale = round(math.sqrt(self.scale / 4))
|
||||
self.model_arch = "ESRGAN-2c2"
|
||||
|
||||
self.supports_fp16 = True
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = None
|
||||
@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
|
||||
out_nc=self.num_filters,
|
||||
upscale_factor=3,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
else:
|
||||
upsample_blocks = [
|
||||
upsample_block(
|
||||
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(int(math.log(self.scale, 2)))
|
||||
]
|
||||
@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
B.ShortcutBlock(
|
||||
B.sequential(
|
||||
@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
|
||||
act_type=self.act,
|
||||
mode="CNA",
|
||||
plus=self.plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(self.num_blocks)
|
||||
],
|
||||
@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
|
||||
norm_type=self.norm,
|
||||
act_type=None,
|
||||
mode=self.mode,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
),
|
||||
@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
# hr_conv1
|
||||
B.conv_block(
|
||||
@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -141,6 +141,19 @@ def sequential(*args):
|
||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
||||
|
||||
|
||||
# 2x2x2 Conv Block
|
||||
def conv_block_2c2(
|
||||
in_nc,
|
||||
out_nc,
|
||||
act_type="relu",
|
||||
):
|
||||
return sequential(
|
||||
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
||||
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
||||
act(act_type) if act_type else None,
|
||||
)
|
||||
|
||||
|
||||
def conv_block(
|
||||
in_nc: int,
|
||||
out_nc: int,
|
||||
@ -153,12 +166,17 @@ def conv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type: str | None = "relu",
|
||||
mode: ConvMode = "CNA",
|
||||
c2x2=False,
|
||||
):
|
||||
"""
|
||||
Conv layer with padding, normalization, activation
|
||||
mode: CNA --> Conv -> Norm -> Act
|
||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||
"""
|
||||
|
||||
if c2x2:
|
||||
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
||||
|
||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
||||
@ -285,6 +303,7 @@ class RRDB(nn.Module):
|
||||
_convtype="Conv2D",
|
||||
_spectral_norm=False,
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(
|
||||
@ -298,6 +317,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB2 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -310,6 +330,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB3 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -322,6 +343,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
act_type="leakyrelu",
|
||||
mode: ConvMode = "CNA",
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv2 = conv_block(
|
||||
nf + gc,
|
||||
@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv3 = conv_block(
|
||||
nf + 2 * gc,
|
||||
@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv4 = conv_block(
|
||||
nf + 3 * gc,
|
||||
@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
if mode == "CNA":
|
||||
last_act = None
|
||||
@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=last_act,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -499,6 +527,7 @@ def upconv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type="relu",
|
||||
mode="nearest",
|
||||
c2x2=False,
|
||||
):
|
||||
# Up conv
|
||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
||||
@ -512,5 +541,6 @@ def upconv_block(
|
||||
pad_type=pad_type,
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
state_dict = state_dict["params"]
|
||||
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
# SRVGGNet Real-ESRGAN (v2)
|
||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
||||
model = RealESRGANv2(state_dict)
|
||||
@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
# MAT
|
||||
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
||||
model = MAT(state_dict)
|
||||
# Omni-SR
|
||||
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
||||
model = OmniSR(state_dict)
|
||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
||||
else:
|
||||
try:
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
||||
from .architecture.Swin2SR import Swin2SR
|
||||
from .architecture.SwinIR import SwinIR
|
||||
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
|
||||
PyTorchSRModel = Union[
|
||||
RealESRGANv2,
|
||||
SPSR,
|
||||
@ -22,6 +23,7 @@ PyTorchSRModel = Union[
|
||||
SwinIR,
|
||||
Swin2SR,
|
||||
HAT,
|
||||
OmniSR,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ class MaskComposite:
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract"],),
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
            }
        }

@@ -193,6 +193,12 @@ class MaskComposite:
            output[top:bottom, left:right] = destination_portion + source_portion
        elif operation == "subtract":
            output[top:bottom, left:right] = destination_portion - source_portion
        elif operation == "and":
            output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "or":
            output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
        elif operation == "xor":
            output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()

        output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
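A standalone sketch of what the new "and"/"or"/"xor" operations do to the overlapping region: masks are rounded to booleans, combined bitwise, and cast back to float. The values below are made up for illustration.

import torch

destination_portion = torch.tensor([[0.0, 0.4, 1.0]])
source_portion = torch.tensor([[1.0, 0.9, 1.0]])

d, s = destination_portion.round().bool(), source_portion.round().bool()
print(torch.bitwise_and(d, s).float())  # tensor([[0., 0., 1.]])
print(torch.bitwise_or(d, s).float())   # tensor([[1., 1., 1.]])
print(torch.bitwise_xor(d, s).float())  # tensor([[1., 1., 0.]])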
440  execution.py
@@ -102,13 +102,21 @@ def get_output_data(obj, input_data_all):
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
return None
|
||||
elif isinstance(x, (int, float, bool, str)):
|
||||
return x
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@ -117,22 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
input_data_formatted = {}
|
||||
if input_data_all is not None:
|
||||
input_data_formatted = {}
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
}
|
||||
return (False, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
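With this change recursive_execute reports failures through its return value instead of raising, so the caller can attribute the error to a node. A hypothetical error_details for a failed node, with field names taken from the code above and values invented:

# On success: (True, None, None). On interrupt: (False, {"node_id": ...}, exception).
# On any other exception, roughly:
success, error_details, ex = (
    False,
    {
        "node_id": "12",
        "exception_message": "shape mismatch",
        "exception_type": "builtins.RuntimeError",
        "traceback": ["  File \"nodes.py\", line 123, in sample\n"],
        "current_inputs": {"seed": [42], "samples": ["<Tensor>"]},
        "current_outputs": {"4": [["<Tensor>"]]},
    },
    RuntimeError("shape mismatch"),
)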
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -210,6 +260,48 @@ class PromptExecutor:
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
node_id = error["node_id"]
|
||||
class_type = prompt[node_id]["class_type"]
|
||||
|
||||
# First, send back the status to the frontend depending
|
||||
# on the exception type
|
||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
}
|
||||
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
||||
else:
|
||||
if self.server.client_id is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
}
|
||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
@ -244,42 +336,30 @@ class PromptExecutor:
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in list(execute_outputs):
|
||||
to_execute += [(0, x)]
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
x = to_execute.pop(0)[-1]
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
except Exception as e:
|
||||
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||
print("Processing interrupted")
|
||||
else:
|
||||
message = str(traceback.format_exc())
|
||||
print(message)
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
finally:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
if success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
@ -297,57 +377,202 @@ def validate_inputs(prompt, item, validated):
|
||||
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
for x in required_inputs:
|
||||
if x not in inputs:
|
||||
return (False, "Required input is missing. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
return (False, "Bad Input. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "bad_linked_input",
|
||||
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] == False:
|
||||
validated[o_id] = r
|
||||
return r
|
||||
received_type = r[val[1]]
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_type": received_type,
|
||||
"linked_node": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
continue
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_inner_validation",
|
||||
"message": "Exception when validating inner node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"linked_node": val
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
continue
|
||||
else:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
try:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
"message": f"Failed to convert an input value to a {type_input} value",
|
||||
"details": f"{x}, {val}, {ex}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
"exception_message": str(ex)
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
return (False, "{}, {}".format(class_type, r))
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
# Don't send back gigantic lists like if they're lots of
|
||||
# scanned model filepaths
|
||||
if len(type_input) > 20:
|
||||
list_info = f"(list of length {len(type_input)})"
|
||||
input_config = None
|
||||
else:
|
||||
list_info = str(type_input)
|
||||
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
else:
|
||||
ret = (True, [], unique_id)
|
||||
|
||||
ret = (True, "")
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
def full_type_name(klass):
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
@ -356,35 +581,86 @@ def validate_prompt(prompt):
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
return (False, "Prompt has no outputs")
|
||||
error = {
|
||||
"type": "prompt_no_outputs",
|
||||
"message": "Prompt has no outputs",
|
||||
"details": "",
|
||||
"extra_info": {}
|
||||
}
|
||||
return (False, error, [], [])
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
node_errors = {}
|
||||
validated = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reason = ""
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_validation",
|
||||
"message": "Exception when validating node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb)
|
||||
}
|
||||
}]
|
||||
validated[o] = (False, reasons, o)
|
||||
|
||||
if valid == True:
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||
print("output will be ignored")
|
||||
errors += [(o, reason)]
|
||||
print(f"Failed to validate prompt for output {o}:")
|
||||
if len(reasons) > 0:
|
||||
print("* (prompt):")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
errors += [(o, reasons)]
|
||||
for node_id, result in validated.items():
|
||||
valid = result[0]
|
||||
reasons = result[1]
|
||||
# If a node upstream has errors, the nodes downstream will also
|
||||
# be reported as invalid, but there will be no errors attached.
|
||||
# So don't return those nodes as having errors in the response.
|
||||
if valid is not True and len(reasons) > 0:
|
||||
if node_id not in node_errors:
|
||||
class_type = prompt[node_id]['class_type']
|
||||
node_errors[node_id] = {
|
||||
"errors": reasons,
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type
|
||||
}
|
||||
print(f"* {class_type} {node_id}:")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print("Output will be ignored")
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||
errors_list = []
|
||||
for o, errors in errors:
|
||||
for error in errors:
|
||||
errors_list.append(f"{error['message']}: {error['details']}")
|
||||
errors_list = "\n".join(errors_list)
|
||||
|
||||
return (True, "", list(good_outputs))
|
||||
error = {
|
||||
"type": "prompt_outputs_failed_validation",
|
||||
"message": "Prompt outputs failed validation",
|
||||
"details": errors_list,
|
||||
"extra_info": {}
|
||||
}
|
||||
|
||||
return (False, error, list(good_outputs), node_errors)
|
||||
|
||||
return (True, None, list(good_outputs), node_errors)
|
||||
|
||||
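validate_prompt now returns a 4-tuple (valid, error, good_output_ids, node_errors) instead of (valid, reason). A hedged sketch of what a failing result might look like; the node ids, class names and messages are invented:

valid, error, good_outputs, node_errors = (
    False,
    {
        "type": "prompt_outputs_failed_validation",
        "message": "Prompt outputs failed validation",
        "details": "Value not in list: ckpt_name: 'missing.ckpt' not in (list of length 42)",
        "extra_info": {},
    },
    [],                  # output node ids that still validated
    {
        "4": {           # node id -> errors the frontend can highlight
            "errors": [{"type": "value_not_in_list", "message": "Value not in list",
                        "details": "ckpt_name: 'missing.ckpt' not in (list of length 42)",
                        "extra_info": {"input_name": "ckpt_name", "input_config": None,
                                       "received_value": "missing.ckpt"}}],
            "dependent_outputs": ["9"],
            "class_type": "CheckpointLoaderSimple",
        }
    },
)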
|
||||
class PromptQueue:
|
||||
|
||||
@ -1,14 +1,7 @@
|
||||
import os
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions.add('.safetensors')
|
||||
supported_pt_extensions.add('.safetensors')
|
||||
except:
|
||||
print("Could not import safetensors, safetensors support disabled.")
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
|
||||
|
||||
folder_names_and_paths = {}
|
||||
|
||||
@ -38,6 +31,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
|
||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
|
||||
filename_list_cache = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
os.makedirs(input_directory)
|
||||
|
||||
@ -118,12 +113,18 @@ def get_folder_paths(folder_name):
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory):
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
result = []
|
||||
dirs = {directory: os.path.getmtime(directory)}
|
||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||
for filepath in file:
|
||||
#we os.path.join directory with a blank string to generate a path separator at the end.
|
||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||
return result
|
||||
for d in subdir:
|
||||
path = os.path.join(root, d)
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
@ -132,20 +133,55 @@ def filter_files_extensions(files, extensions):
|
||||
|
||||
def get_full_path(folder_name, filename):
|
||||
global folder_names_and_paths
|
||||
if folder_name not in folder_names_and_paths:
|
||||
return None
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
filename = os.path.relpath(os.path.join("/", filename), "/")
|
||||
for x in folders[0]:
|
||||
full_path = os.path.join(x, filename)
|
||||
if os.path.isfile(full_path):
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
def get_filename_list_(folder_name):
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
output_folders = {}
|
||||
for x in folders[0]:
|
||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||
return sorted(list(output_list))
|
||||
files, folders_all = recursive_search(x)
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return (sorted(list(output_list)), output_folders)
|
||||
|
||||
def cached_filename_list_(folder_name):
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
if x not in out[1]:
|
||||
return None
|
||||
|
||||
return out
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
global filename_list_cache
|
||||
filename_list_cache[folder_name] = out
|
||||
return list(out[0])
|
||||
|
||||
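The cache entry per folder name is exactly what get_filename_list_ returns: the sorted file list plus an mtime snapshot of every directory walked. cached_filename_list_ returns None (forcing a rescan) whenever any recorded directory's mtime changed or a configured search path is missing from the snapshot. A simplified restatement of that rule; the real code reads current mtimes with os.path.getmtime and the paths here are made up:

def cache_is_stale(mtimes_at_scan, mtimes_now):
    # any directory touched since the scan invalidates the whole cached list
    return any(mtimes_now.get(d) != t for d, t in mtimes_at_scan.items())

print(cache_is_stale({"/models/checkpoints": 1700000000.0},
                     {"/models/checkpoints": 1700000000.0}))  # False -> cache reused
print(cache_is_stale({"/models/checkpoints": 1700000000.0},
                     {"/models/checkpoints": 1700000123.0}))  # True  -> rescan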
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
|
||||
36  nodes.py
@@ -17,7 +17,7 @@ import safetensors.torch
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
|
||||
import comfy.diffusers_convert
|
||||
import comfy.diffusers_load
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
import comfy.sd
|
||||
@ -377,7 +377,7 @@ class DiffusersLoader:
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
|
||||
class unCLIPCheckpointLoader:
|
||||
@ -426,6 +426,9 @@ class LoraLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (model, clip)
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
@ -507,6 +510,9 @@ class ControlNetApply:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
for t in conditioning:
|
||||
@ -613,6 +619,9 @@ class unCLIPConditioning:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
for t in conditioning:
|
||||
o = t[1].copy()
|
||||
@ -749,7 +758,7 @@ class RepeatLatentBatch:
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
@classmethod
|
||||
@ -768,6 +777,25 @@ class LatentUpscale:
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
return (s,)
|
||||
|
||||
class LatentUpscaleBy:
    upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
                              "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"

    CATEGORY = "latent"

    def upscale(self, samples, upscale_method, scale_by):
        s = samples.copy()
        width = round(samples["samples"].shape[3] * scale_by)
        height = round(samples["samples"].shape[2] * scale_by)
        s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
        return (s,)
|
||||
|
||||
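LatentUpscaleBy scales the latent's spatial dims directly, which for SD latents are 1/8 of the decoded image size. A quick arithmetic sketch:

scale_by = 1.5
latent_h, latent_w = 64, 64                       # a 512x512 image encodes to a 64x64 latent

new_h, new_w = round(latent_h * scale_by), round(latent_w * scale_by)
print(new_h, new_w)                               # 96 96 -> decodes to roughly 768x768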
class LatentRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1244,6 +1272,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAELoader": VAELoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentUpscaleBy": LatentUpscaleBy,
|
||||
"LatentFromBatch": LatentFromBatch,
|
||||
"RepeatLatentBatch": RepeatLatentBatch,
|
||||
"SaveImage": SaveImage,
|
||||
@ -1322,6 +1351,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LatentCrop": "Crop Latent",
|
||||
"EmptyLatentImage": "Empty Latent Image",
|
||||
"LatentUpscale": "Upscale Latent",
|
||||
"LatentUpscaleBy": "Upscale Latent By",
|
||||
"LatentComposite": "Latent Composite",
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
|
||||
36  server.py
@@ -22,7 +22,7 @@ except ImportError:
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
|
||||
import comfy.utils
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
@ -257,6 +257,29 @@ class PromptServer():
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None:
                return web.Response(status=404)
            if not "filename" in request.rel_url.query:
                return web.Response(status=404)

            filename = request.rel_url.query["filename"]
            if not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)
            out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if out is None:
                return web.Response(status=404)
            dt = json.loads(out)
            if not "__metadata__" in dt:
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])
|
||||
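The endpoint above returns only the __metadata__ block of a .safetensors header. A hedged usage sketch against a local server; the default port, the 'loras' folder name and the filename are assumptions.

import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({"filename": "example_lora.safetensors"})
url = f"http://127.0.0.1:8188/view_metadata/loras?{params}"
with urllib.request.urlopen(url) as resp:       # 404 if the file or its metadata is missing
    metadata = json.load(resp)
print(list(metadata.keys()))                    # whatever the checkpoint writer stored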
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
@ -272,6 +295,11 @@ class PromptServer():
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
return info
|
||||
@ -333,12 +361,12 @@ class PromptServer():
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
return web.json_response({"prompt_id": prompt_id})
|
||||
return web.json_response({"prompt_id": prompt_id, "number": number})
|
||||
else:
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1]}, status=400)
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
else:
|
||||
return web.json_response({"error": "no prompt"}, status=400)
|
||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||
|
||||
@routes.post("/queue")
|
||||
async def post_queue(request):
|
||||
|
||||
@ -174,7 +174,7 @@ const els = {}
|
||||
// const ctxMenu = LiteGraph.ContextMenu;
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
init() {
|
||||
addCustomNodeDefs(node_defs) {
|
||||
const sortObjectKeys = (unordered) => {
|
||||
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||
obj[key] = unordered[key];
|
||||
@ -182,10 +182,10 @@ app.registerExtension({
|
||||
}, {});
|
||||
};
|
||||
|
||||
const getSlotTypes = async () => {
|
||||
function getSlotTypes() {
|
||||
var types = [];
|
||||
|
||||
const defs = await api.getNodeDefs();
|
||||
const defs = node_defs;
|
||||
for (const nodeId in defs) {
|
||||
const nodeData = defs[nodeId];
|
||||
|
||||
@ -212,8 +212,8 @@ app.registerExtension({
|
||||
return types;
|
||||
};
|
||||
|
||||
const completeColorPalette = async (colorPalette) => {
|
||||
var types = await getSlotTypes();
|
||||
function completeColorPalette(colorPalette) {
|
||||
var types = getSlotTypes();
|
||||
|
||||
for (const type of types) {
|
||||
if (!colorPalette.colors.node_slot[type]) {
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
window.graph = app.graph;
|
||||
</script>
|
||||
</head>
|
||||
<body></body>
|
||||
<body class="litegraph"></body>
|
||||
</html>
|
||||
|
||||
@ -88,6 +88,12 @@ class ComfyApi extends EventTarget {
            case "executed":
                this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
                break;
            case "execution_start":
                this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
                break;
            case "execution_error":
                this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
                break;
            default:
                if (this.#registered.has(msg.type)) {
                    this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
@ -771,16 +771,27 @@ export class ComfyApp {
        LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
            const res = origDrawNodeShape.apply(this, arguments);

            const nodeErrors = self.lastPromptError?.node_errors[node.id];

            let color = null;
            let lineWidth = 1;
            if (node.id === +self.runningNodeId) {
                color = "#0f0";
            } else if (self.dragOverNode && node.id === self.dragOverNode.id) {
                color = "dodgerblue";
            }
            else if (self.lastPromptError != null && nodeErrors?.errors) {
                color = "red";
                lineWidth = 2;
            }
            else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
                color = "#f0f";
                lineWidth = 2;
            }

            if (color) {
                const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
                ctx.lineWidth = 1;
                ctx.lineWidth = lineWidth;
                ctx.globalAlpha = 0.8;
                ctx.beginPath();
                if (shape == LiteGraph.BOX_SHAPE)
@ -807,11 +818,28 @@ export class ComfyApp {
                ctx.stroke();
                ctx.strokeStyle = fgcolor;
                ctx.globalAlpha = 1;
            }

            if (self.progress) {
                ctx.fillStyle = "green";
                ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
                ctx.fillStyle = bgcolor;
            if (self.progress && node.id === +self.runningNodeId) {
                ctx.fillStyle = "green";
                ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
                ctx.fillStyle = bgcolor;
            }

            // Highlight inputs that failed validation
            if (nodeErrors) {
                ctx.lineWidth = 2;
                ctx.strokeStyle = "red";
                for (const error of nodeErrors.errors) {
                    if (error.extra_info && error.extra_info.input_name) {
                        const inputIndex = node.findInputSlot(error.extra_info.input_name)
                        if (inputIndex !== -1) {
                            let pos = node.getConnectionPos(true, inputIndex);
                            ctx.beginPath();
                            ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
                            ctx.stroke();
                        }
                    }
                }
            }
@ -869,6 +897,17 @@ export class ComfyApp {
            }
        });

        api.addEventListener("execution_start", ({ detail }) => {
            this.lastExecutionError = null
        });

        api.addEventListener("execution_error", ({ detail }) => {
            this.lastExecutionError = detail;
            const formattedError = this.#formatExecutionError(detail);
            this.ui.dialog.show(formattedError);
            this.canvas.draw(true, true);
        });

        api.init();
    }
@ -971,6 +1010,11 @@ export class ComfyApp {
        const app = this;
        // Load node definitions from the backend
        const defs = await api.getNodeDefs();
        await this.registerNodesFromDefs(defs);
        await this.#invokeExtensionsAsync("registerCustomNodes");
    }

    async registerNodesFromDefs(defs) {
        await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);

        // Generate list of known widgets
@ -1043,8 +1087,6 @@ export class ComfyApp {
            LiteGraph.registerNodeType(nodeId, node);
            node.category = nodeData.category;
        }

        await this.#invokeExtensionsAsync("registerCustomNodes");
    }

    /**
@ -1243,6 +1285,43 @@ export class ComfyApp {
        return { workflow, output };
    }

    #formatPromptError(error) {
        if (error == null) {
            return "(unknown error)"
        }
        else if (typeof error === "string") {
            return error;
        }
        else if (error.stack && error.message) {
            return error.toString()
        }
        else if (error.response) {
            let message = error.response.error.message;
            if (error.response.error.details)
                message += ": " + error.response.error.details;
            for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
                message += "\n" + nodeError.class_type + ":"
                for (const errorReason of nodeError.errors) {
                    message += "\n - " + errorReason.message + ": " + errorReason.details
                }
            }
            return message
        }
        return "(unknown error)"
    }

    #formatExecutionError(error) {
        if (error == null) {
            return "(unknown error)"
        }

        const traceback = error.traceback.join("")
        const nodeId = error.node_id
        const nodeType = error.node_type

        return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
    }

    async queuePrompt(number, batchCount = 1) {
        this.#queueItems.push({ number, batchCount });
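For reference, #formatExecutionError above flattens the execution_error payload into the dialog text. A hypothetical payload and the message built from it are sketched below; the field names are the ones the method reads, the values are invented.

# Hypothetical execution_error detail and the message #formatExecutionError builds from it.
detail = {
    "node_id": "7",
    "node_type": "KSampler",
    "exception_message": "CUDA out of memory",
    "traceback": ["Traceback (most recent call last):\n", "  ...\n"],
}
message = "Error occurred when executing {}:\n\n{}\n\n{}".format(
    detail["node_type"], detail["exception_message"], "".join(detail["traceback"])
)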
@ -1250,8 +1329,10 @@ export class ComfyApp {
        if (this.#processingQueue) {
            return;
        }

        this.#processingQueue = true;
        this.lastPromptError = null;

        try {
            while (this.#queueItems.length) {
                ({ number, batchCount } = this.#queueItems.pop());
@ -1278,7 +1359,12 @@ export class ComfyApp {
                try {
                    await api.queuePrompt(number, p);
                } catch (error) {
                    this.ui.dialog.show(error.response.error || error.toString());
                    const formattedError = this.#formatPromptError(error)
                    this.ui.dialog.show(formattedError);
                    if (error.response) {
                        this.lastPromptError = error.response;
                        this.canvas.draw(true, true);
                    }
                    break;
                }
@ -1357,6 +1443,11 @@ export class ComfyApp {

            const def = defs[node.type];

            // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
            // and additional work is needed to consider the primitive logic in the refresh logic.
            if(!def)
                continue;

            for(const widgetNum in node.widgets) {
                const widget = node.widgets[widgetNum]
                if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
@ -1376,6 +1467,8 @@ export class ComfyApp {
     */
    clean() {
        this.nodeOutputs = {};
        this.lastPromptError = null;
        this.lastExecutionError = null;
    }
}
@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) {
    const embeddings = await api.getEmbeddings();
    const opts = parameters
        .substr(p)
        .split("\n")[1]
        .split(",")
        .reduce((p, n) => {
            const s = n.split(":");
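For reference, the reduce above splits one line of an A1111 "parameters" block on "," and then on ":" to build a key/value options map. A rough Python rendering of the same idea on an invented options line is below; the real code's key normalization is not shown in this hunk.

# Invented A1111 options line; the reduce above splits on "," then on ":".
opts_line = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42"
opts = {}
for pair in opts_line.split(","):
    key, _, value = pair.partition(":")
    opts[key.strip()] = value.strip()
# {'Steps': '20', 'Sampler': 'Euler a', 'CFG scale': '7', 'Seed': '42'}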
@ -289,6 +289,11 @@ button.comfy-queue-btn {

/* Context menu */

.litegraph .dialog {
    z-index: 1;
    font-family: Arial;
}

.litegraph .litemenu-entry.has_submenu {
    position: relative;
    padding-right: 20px;