mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
Merge branch 'comfyanonymous:master' into feature/blockweights
This commit is contained in:
commit
23332731bd
@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
|
|||||||
else:
|
else:
|
||||||
raise AssertionError('Unknown merge analysis result')
|
raise AssertionError('Unknown merge analysis result')
|
||||||
|
|
||||||
|
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
|
||||||
repo = pygit2.Repository(str(sys.argv[1]))
|
repo = pygit2.Repository(str(sys.argv[1]))
|
||||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -30,6 +30,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
persist-credentials: false
|
||||||
- shell: bash
|
- shell: bash
|
||||||
run: |
|
run: |
|
||||||
cd ..
|
cd ..
|
||||||
|
|||||||
@ -17,6 +17,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
persist-credentials: false
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.11.3'
|
python-version: '3.11.3'
|
||||||
|
|||||||
@ -1,14 +1,5 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
from comfy.ldm.util import instantiate_from_config
|
|
||||||
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
|
|
||||||
import os.path as osp
|
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
|
||||||
|
|
||||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||||
|
|
||||||
@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
|
|||||||
return text_enc_dict
|
return text_enc_dict
|
||||||
|
|
||||||
|
|
||||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
|
||||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
|
||||||
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
|
||||||
|
|
||||||
# magic
|
|
||||||
v2 = diffusers_unet_conf["sample_size"] == 96
|
|
||||||
if 'prediction_type' in diffusers_scheduler_conf:
|
|
||||||
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
|
||||||
|
|
||||||
if v2:
|
|
||||||
if v_pred:
|
|
||||||
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
|
||||||
else:
|
|
||||||
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
|
||||||
else:
|
|
||||||
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
|
||||||
|
|
||||||
with open(config_path, 'r') as stream:
|
|
||||||
config = yaml.safe_load(stream)
|
|
||||||
|
|
||||||
model_config_params = config['model']['params']
|
|
||||||
clip_config = model_config_params['cond_stage_config']
|
|
||||||
scale_factor = model_config_params['scale_factor']
|
|
||||||
vae_config = model_config_params['first_stage_config']
|
|
||||||
vae_config['scale_factor'] = scale_factor
|
|
||||||
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
|
||||||
|
|
||||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
|
||||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
|
||||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
|
||||||
|
|
||||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
|
||||||
if osp.exists(unet_path):
|
|
||||||
unet_state_dict = load_file(unet_path, device="cpu")
|
|
||||||
else:
|
|
||||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
|
||||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
|
||||||
|
|
||||||
if osp.exists(vae_path):
|
|
||||||
vae_state_dict = load_file(vae_path, device="cpu")
|
|
||||||
else:
|
|
||||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
|
||||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
|
||||||
|
|
||||||
if osp.exists(text_enc_path):
|
|
||||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
|
||||||
else:
|
|
||||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
|
||||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
|
||||||
|
|
||||||
# Convert the UNet model
|
|
||||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
|
||||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
|
||||||
|
|
||||||
# Convert the VAE model
|
|
||||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
|
||||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
|
||||||
|
|
||||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
|
||||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
|
||||||
|
|
||||||
if is_v20_model:
|
|
||||||
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
|
||||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
|
||||||
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
|
||||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
|
||||||
else:
|
|
||||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
|
||||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
|
||||||
|
|
||||||
# Put together new checkpoint
|
|
||||||
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
|
||||||
|
|
||||||
clip = None
|
|
||||||
vae = None
|
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
w = WeightsLoader()
|
|
||||||
load_state_dict_to = []
|
|
||||||
if output_vae:
|
|
||||||
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
|
||||||
w.first_stage_model = vae.first_stage_model
|
|
||||||
load_state_dict_to = [w]
|
|
||||||
|
|
||||||
if output_clip:
|
|
||||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
|
||||||
load_state_dict_to = [w]
|
|
||||||
|
|
||||||
model = instantiate_from_config(config["model"])
|
|
||||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
|
||||||
|
|
||||||
if fp16:
|
|
||||||
model = model.half()
|
|
||||||
|
|
||||||
return ModelPatcher(model), clip, vae
|
|
||||||
|
|||||||
111
comfy/diffusers_load.py
Normal file
111
comfy/diffusers_load.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from comfy.ldm.util import instantiate_from_config
|
||||||
|
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
|
||||||
|
import os.path as osp
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
import diffusers_convert
|
||||||
|
|
||||||
|
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||||
|
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||||
|
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
||||||
|
|
||||||
|
# magic
|
||||||
|
v2 = diffusers_unet_conf["sample_size"] == 96
|
||||||
|
if 'prediction_type' in diffusers_scheduler_conf:
|
||||||
|
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
||||||
|
|
||||||
|
if v2:
|
||||||
|
if v_pred:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
||||||
|
else:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
||||||
|
else:
|
||||||
|
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
||||||
|
|
||||||
|
with open(config_path, 'r') as stream:
|
||||||
|
config = yaml.safe_load(stream)
|
||||||
|
|
||||||
|
model_config_params = config['model']['params']
|
||||||
|
clip_config = model_config_params['cond_stage_config']
|
||||||
|
scale_factor = model_config_params['scale_factor']
|
||||||
|
vae_config = model_config_params['first_stage_config']
|
||||||
|
vae_config['scale_factor'] = scale_factor
|
||||||
|
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
||||||
|
|
||||||
|
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||||
|
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||||
|
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||||
|
|
||||||
|
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||||
|
if osp.exists(unet_path):
|
||||||
|
unet_state_dict = load_file(unet_path, device="cpu")
|
||||||
|
else:
|
||||||
|
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||||
|
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||||
|
|
||||||
|
if osp.exists(vae_path):
|
||||||
|
vae_state_dict = load_file(vae_path, device="cpu")
|
||||||
|
else:
|
||||||
|
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||||
|
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||||
|
|
||||||
|
if osp.exists(text_enc_path):
|
||||||
|
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||||
|
else:
|
||||||
|
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||||
|
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Convert the UNet model
|
||||||
|
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
|
||||||
|
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||||
|
|
||||||
|
# Convert the VAE model
|
||||||
|
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
|
||||||
|
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||||
|
|
||||||
|
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||||
|
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||||
|
|
||||||
|
if is_v20_model:
|
||||||
|
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||||
|
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
|
||||||
|
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
else:
|
||||||
|
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
|
||||||
|
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||||
|
|
||||||
|
# Put together new checkpoint
|
||||||
|
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||||
|
|
||||||
|
clip = None
|
||||||
|
vae = None
|
||||||
|
|
||||||
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
w = WeightsLoader()
|
||||||
|
load_state_dict_to = []
|
||||||
|
if output_vae:
|
||||||
|
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
||||||
|
w.first_stage_model = vae.first_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
if output_clip:
|
||||||
|
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||||
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
model = instantiate_from_config(config["model"])
|
||||||
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
|
|
||||||
|
if fp16:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
return ModelPatcher(model), clip, vae
|
||||||
@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
|
"""DPM-Solver++(2M) SDE."""
|
||||||
|
|
||||||
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
|
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
old_denoised = None
|
||||||
|
h_last = None
|
||||||
|
h = None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Denoising step
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2M) SDE
|
||||||
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = s - t
|
||||||
|
eta_h = eta * h
|
||||||
|
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||||
|
|
||||||
|
if old_denoised is not None:
|
||||||
|
r = h_last / h
|
||||||
|
if solver_type == 'heun':
|
||||||
|
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||||
|
elif solver_type == 'midpoint':
|
||||||
|
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||||
|
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
|
old_denoised = denoised
|
||||||
|
h_last = h
|
||||||
|
return x
|
||||||
|
|||||||
@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
return x+h
|
return x+h
|
||||||
|
|
||||||
|
def slice_attention(q, k, v):
|
||||||
|
r1 = torch.zeros_like(k, device=q.device)
|
||||||
|
scale = (int(q.shape[-1])**(-0.5))
|
||||||
|
|
||||||
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||||
|
|
||||||
|
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||||
|
del s2
|
||||||
|
break
|
||||||
|
except model_management.OOM_EXCEPTION as e:
|
||||||
|
steps *= 2
|
||||||
|
if steps > 128:
|
||||||
|
raise e
|
||||||
|
print("out of memory error, increasing steps and trying again", steps)
|
||||||
|
|
||||||
|
return r1
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
|
|||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
b,c,h,w = q.shape
|
b,c,h,w = q.shape
|
||||||
scale = (int(c)**(-0.5))
|
|
||||||
|
|
||||||
q = q.reshape(b,c,h*w)
|
q = q.reshape(b,c,h*w)
|
||||||
q = q.permute(0,2,1) # b,hw,c
|
q = q.permute(0,2,1) # b,hw,c
|
||||||
k = k.reshape(b,c,h*w) # b,c,hw
|
k = k.reshape(b,c,h*w) # b,c,hw
|
||||||
v = v.reshape(b,c,h*w)
|
v = v.reshape(b,c,h*w)
|
||||||
|
|
||||||
r1 = torch.zeros_like(k, device=q.device)
|
r1 = slice_attention(q, k, v)
|
||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
|
||||||
mem_required = tensor_size * modifier
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
|
||||||
|
|
||||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
|
||||||
del s2
|
|
||||||
break
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
|
||||||
steps *= 2
|
|
||||||
if steps > 128:
|
|
||||||
raise e
|
|
||||||
print("out of memory error, increasing steps and trying again", steps)
|
|
||||||
|
|
||||||
h_ = r1.reshape(b,c,h,w)
|
h_ = r1.reshape(b,c,h,w)
|
||||||
del r1
|
del r1
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
return x+h_
|
return x+h_
|
||||||
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
B, C, H, W = q.shape
|
B, C, H, W = q.shape
|
||||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
|
||||||
|
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.unsqueeze(3)
|
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||||
.reshape(B, t.shape[1], 1, C)
|
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(B * 1, t.shape[1], C)
|
|
||||||
.contiguous(),
|
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
|
||||||
|
|
||||||
out = (
|
try:
|
||||||
out.unsqueeze(0)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
.reshape(B, 1, out.shape[1], C)
|
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||||
.permute(0, 2, 1, 3)
|
except model_management.OOM_EXCEPTION as e:
|
||||||
.reshape(B, out.shape[1], C)
|
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
)
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x+out
|
return x+out
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
"""
|
"""
|
||||||
B, N, _ = metric.shape
|
B, N, _ = metric.shape
|
||||||
|
|
||||||
if r <= 0:
|
if r <= 0 or w == 1 or h == 1:
|
||||||
return do_nothing, do_nothing
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||||
|
|||||||
@ -1,23 +1,29 @@
|
|||||||
import psutil
|
import psutil
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
CPU = 0
|
DISABLED = 0
|
||||||
NO_VRAM = 1
|
NO_VRAM = 1
|
||||||
LOW_VRAM = 2
|
LOW_VRAM = 2
|
||||||
NORMAL_VRAM = 3
|
NORMAL_VRAM = 3
|
||||||
HIGH_VRAM = 4
|
HIGH_VRAM = 4
|
||||||
MPS = 5
|
SHARED = 5
|
||||||
|
|
||||||
|
class CPUState(Enum):
|
||||||
|
GPU = 0
|
||||||
|
CPU = 1
|
||||||
|
MPS = 2
|
||||||
|
|
||||||
# Determine VRAM State
|
# Determine VRAM State
|
||||||
vram_state = VRAMState.NORMAL_VRAM
|
vram_state = VRAMState.NORMAL_VRAM
|
||||||
set_vram_to = VRAMState.NORMAL_VRAM
|
set_vram_to = VRAMState.NORMAL_VRAM
|
||||||
|
cpu_state = CPUState.GPU
|
||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
total_vram_available_mb = -1
|
|
||||||
|
|
||||||
accelerate_enabled = False
|
lowvram_available = True
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
|
||||||
directml_enabled = False
|
directml_enabled = False
|
||||||
@ -31,30 +37,80 @@ if args.directml is not None:
|
|||||||
directml_device = torch_directml.device(device_index)
|
directml_device = torch_directml.device(device_index)
|
||||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||||
# torch_directml.disable_tiled_resources(True)
|
# torch_directml.disable_tiled_resources(True)
|
||||||
|
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import intel_extension_for_pytorch as ipex
|
||||||
if directml_enabled:
|
if torch.xpu.is_available():
|
||||||
total_vram = 4097 #TODO
|
xpu_available = True
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
if torch.xpu.is_available():
|
|
||||||
xpu_available = True
|
|
||||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
|
||||||
except:
|
|
||||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
|
||||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
|
||||||
if not args.normalvram and not args.cpu:
|
|
||||||
if total_vram <= 4096:
|
|
||||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
|
||||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
|
||||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
|
||||||
vram_state = VRAMState.HIGH_VRAM
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
cpu_state = CPUState.MPS
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
cpu_state = CPUState.CPU
|
||||||
|
|
||||||
|
def get_torch_device():
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
global cpu_state
|
||||||
|
if directml_enabled:
|
||||||
|
global directml_device
|
||||||
|
return directml_device
|
||||||
|
if cpu_state == CPUState.MPS:
|
||||||
|
return torch.device("mps")
|
||||||
|
if cpu_state == CPUState.CPU:
|
||||||
|
return torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if xpu_available:
|
||||||
|
return torch.device("xpu")
|
||||||
|
else:
|
||||||
|
return torch.device(torch.cuda.current_device())
|
||||||
|
|
||||||
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
if dev is None:
|
||||||
|
dev = get_torch_device()
|
||||||
|
|
||||||
|
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||||
|
mem_total = psutil.virtual_memory().total
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
else:
|
||||||
|
if directml_enabled:
|
||||||
|
mem_total = 1024 * 1024 * 1024 #TODO
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
elif xpu_available:
|
||||||
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
|
mem_total_torch = mem_total
|
||||||
|
else:
|
||||||
|
stats = torch.cuda.memory_stats(dev)
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
||||||
|
mem_total_torch = mem_reserved
|
||||||
|
mem_total = mem_total_cuda
|
||||||
|
|
||||||
|
if torch_total_too:
|
||||||
|
return (mem_total, mem_total_torch)
|
||||||
|
else:
|
||||||
|
return mem_total
|
||||||
|
|
||||||
|
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||||
|
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||||
|
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||||
|
if not args.normalvram and not args.cpu:
|
||||||
|
if lowvram_available and total_vram <= 4096:
|
||||||
|
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||||
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
|
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||||
|
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||||
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||||
except:
|
except:
|
||||||
@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
|||||||
|
|
||||||
if args.lowvram:
|
if args.lowvram:
|
||||||
set_vram_to = VRAMState.LOW_VRAM
|
set_vram_to = VRAMState.LOW_VRAM
|
||||||
|
lowvram_available = True
|
||||||
elif args.novram:
|
elif args.novram:
|
||||||
set_vram_to = VRAMState.NO_VRAM
|
set_vram_to = VRAMState.NO_VRAM
|
||||||
elif args.highvram:
|
elif args.highvram:
|
||||||
@ -102,32 +159,42 @@ if args.force_fp32:
|
|||||||
print("Forcing FP32, if this improves things please report it.")
|
print("Forcing FP32, if this improves things please report it.")
|
||||||
FORCE_FP32 = True
|
FORCE_FP32 = True
|
||||||
|
|
||||||
|
if lowvram_available:
|
||||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
|
||||||
try:
|
try:
|
||||||
import accelerate
|
import accelerate
|
||||||
accelerate_enabled = True
|
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||||
vram_state = set_vram_to
|
vram_state = set_vram_to
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
|
||||||
|
lowvram_available = False
|
||||||
|
|
||||||
total_vram_available_mb = (total_vram - 1024) // 2
|
|
||||||
total_vram_available_mb = int(max(256, total_vram_available_mb))
|
|
||||||
|
|
||||||
try:
|
if cpu_state != CPUState.GPU:
|
||||||
if torch.backends.mps.is_available():
|
vram_state = VRAMState.DISABLED
|
||||||
vram_state = VRAMState.MPS
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if args.cpu:
|
if cpu_state == CPUState.MPS:
|
||||||
vram_state = VRAMState.CPU
|
vram_state = VRAMState.SHARED
|
||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_torch_device_name(device):
|
||||||
|
if hasattr(device, 'type'):
|
||||||
|
if device.type == "cuda":
|
||||||
|
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
else:
|
||||||
|
return "{}".format(device.type)
|
||||||
|
else:
|
||||||
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Device:", get_torch_device_name(get_torch_device()))
|
||||||
|
except:
|
||||||
|
print("Could not pick default device.")
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
current_gpu_controlnets = []
|
current_gpu_controlnets = []
|
||||||
|
|
||||||
@ -173,22 +240,29 @@ def load_model_gpu(model):
|
|||||||
model.unpatch_model()
|
model.unpatch_model()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
model.model_patches_to(get_torch_device())
|
torch_dev = get_torch_device()
|
||||||
|
model.model_patches_to(torch_dev)
|
||||||
|
|
||||||
|
vram_set_state = vram_state
|
||||||
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||||
|
model_size = model.model_size()
|
||||||
|
current_free_mem = get_free_memory(torch_dev)
|
||||||
|
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||||
|
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
|
||||||
|
vram_set_state = VRAMState.LOW_VRAM
|
||||||
|
|
||||||
current_loaded_model = model
|
current_loaded_model = model
|
||||||
if vram_state == VRAMState.CPU:
|
|
||||||
|
if vram_set_state == VRAMState.DISABLED:
|
||||||
pass
|
pass
|
||||||
elif vram_state == VRAMState.MPS:
|
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||||
mps_device = torch.device("mps")
|
|
||||||
real_model.to(mps_device)
|
|
||||||
pass
|
|
||||||
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
real_model.to(get_torch_device())
|
real_model.to(get_torch_device())
|
||||||
else:
|
else:
|
||||||
if vram_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||||
elif vram_state == VRAMState.LOW_VRAM:
|
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||||
|
|
||||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||||
model_accelerated = True
|
model_accelerated = True
|
||||||
@ -197,7 +271,7 @@ def load_model_gpu(model):
|
|||||||
def load_controlnet_gpu(control_models):
|
def load_controlnet_gpu(control_models):
|
||||||
global current_gpu_controlnets
|
global current_gpu_controlnets
|
||||||
global vram_state
|
global vram_state
|
||||||
if vram_state == VRAMState.CPU:
|
if vram_state == VRAMState.DISABLED:
|
||||||
return
|
return
|
||||||
|
|
||||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||||
@ -233,22 +307,6 @@ def unload_if_low_vram(model):
|
|||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_torch_device():
|
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
|
||||||
if directml_enabled:
|
|
||||||
global directml_device
|
|
||||||
return directml_device
|
|
||||||
if vram_state == VRAMState.MPS:
|
|
||||||
return torch.device("mps")
|
|
||||||
if vram_state == VRAMState.CPU:
|
|
||||||
return torch.device("cpu")
|
|
||||||
else:
|
|
||||||
if xpu_available:
|
|
||||||
return torch.device("xpu")
|
|
||||||
else:
|
|
||||||
return torch.cuda.current_device()
|
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
@ -258,7 +316,8 @@ def get_autocast_device(dev):
|
|||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if vram_state == VRAMState.CPU:
|
global cpu_state
|
||||||
|
if cpu_state != CPUState.GPU:
|
||||||
return False
|
return False
|
||||||
if xpu_available:
|
if xpu_available:
|
||||||
return False
|
return False
|
||||||
@ -330,12 +389,12 @@ def maximum_batch_area():
|
|||||||
return int(max(area, 0))
|
return int(max(area, 0))
|
||||||
|
|
||||||
def cpu_mode():
|
def cpu_mode():
|
||||||
global vram_state
|
global cpu_state
|
||||||
return vram_state == VRAMState.CPU
|
return cpu_state == CPUState.CPU
|
||||||
|
|
||||||
def mps_mode():
|
def mps_mode():
|
||||||
global vram_state
|
global cpu_state
|
||||||
return vram_state == VRAMState.MPS
|
return cpu_state == CPUState.MPS
|
||||||
|
|
||||||
def should_use_fp16():
|
def should_use_fp16():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
@ -367,7 +426,10 @@ def should_use_fp16():
|
|||||||
|
|
||||||
def soft_empty_cache():
|
def soft_empty_cache():
|
||||||
global xpu_available
|
global xpu_available
|
||||||
if xpu_available:
|
global cpu_state
|
||||||
|
if cpu_state == CPUState.MPS:
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
elif xpu_available:
|
||||||
torch.xpu.empty_cache()
|
torch.xpu.empty_cache()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||||
|
|||||||
@ -2,17 +2,26 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import math
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def prepare_noise(latent_image, seed, skip=0):
|
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||||
"""
|
"""
|
||||||
creates random noise given a latent image and a seed.
|
creates random noise given a latent image and a seed.
|
||||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||||
"""
|
"""
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
for _ in range(skip):
|
if noise_inds is None:
|
||||||
|
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
|
|
||||||
|
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||||
|
noises = []
|
||||||
|
for i in range(unique_inds[-1]+1):
|
||||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
if i in unique_inds:
|
||||||
return noise
|
noises.append(noise)
|
||||||
|
noises = [noises[i] for i in inverse]
|
||||||
|
noises = torch.cat(noises, axis=0)
|
||||||
|
return noises
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_mask(noise_mask, shape, device):
|
||||||
"""ensures noise mask is of proper dimensions"""
|
"""ensures noise mask is of proper dimensions"""
|
||||||
|
|||||||
@ -6,6 +6,10 @@ import contextlib
|
|||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
|
import math
|
||||||
|
|
||||||
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||||
|
return abs(a*b) // math.gcd(a, b)
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if c1.keys() != c2.keys():
|
if c1.keys() != c2.keys():
|
||||||
return False
|
return False
|
||||||
if 'c_crossattn' in c1:
|
if 'c_crossattn' in c1:
|
||||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
s1 = c1['c_crossattn'].shape
|
||||||
return False
|
s2 = c2['c_crossattn'].shape
|
||||||
|
if s1 != s2:
|
||||||
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||||
|
return False
|
||||||
|
|
||||||
|
mult_min = lcm(s1[1], s2[1])
|
||||||
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
|
return False
|
||||||
if 'c_concat' in c1:
|
if 'c_concat' in c1:
|
||||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||||
return False
|
return False
|
||||||
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
c_crossattn = []
|
c_crossattn = []
|
||||||
c_concat = []
|
c_concat = []
|
||||||
c_adm = []
|
c_adm = []
|
||||||
|
crossattn_max_len = 0
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
if 'c_crossattn' in x:
|
if 'c_crossattn' in x:
|
||||||
c_crossattn.append(x['c_crossattn'])
|
c = x['c_crossattn']
|
||||||
|
if crossattn_max_len == 0:
|
||||||
|
crossattn_max_len = c.shape[1]
|
||||||
|
else:
|
||||||
|
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||||
|
c_crossattn.append(c)
|
||||||
if 'c_concat' in x:
|
if 'c_concat' in x:
|
||||||
c_concat.append(x['c_concat'])
|
c_concat.append(x['c_concat'])
|
||||||
if 'c_adm' in x:
|
if 'c_adm' in x:
|
||||||
c_adm.append(x['c_adm'])
|
c_adm.append(x['c_adm'])
|
||||||
out = {}
|
out = {}
|
||||||
if len(c_crossattn) > 0:
|
c_crossattn_out = []
|
||||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
for c in c_crossattn:
|
||||||
|
if c.shape[1] < crossattn_max_len:
|
||||||
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||||
|
c_crossattn_out.append(c)
|
||||||
|
|
||||||
|
if len(c_crossattn_out) > 0:
|
||||||
|
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||||
if len(c_concat) > 0:
|
if len(c_concat) > 0:
|
||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
if len(c_adm) > 0:
|
if len(c_adm) > 0:
|
||||||
@ -362,19 +386,8 @@ def resolve_cond_masks(conditions, h, w, device):
|
|||||||
else:
|
else:
|
||||||
box = boxes[0]
|
box = boxes[0]
|
||||||
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
||||||
# Make sure the height and width are divisible by 8
|
H = max(8, H)
|
||||||
if X % 8 != 0:
|
W = max(8, W)
|
||||||
newx = X // 8 * 8
|
|
||||||
W = W + (X - newx)
|
|
||||||
X = newx
|
|
||||||
if Y % 8 != 0:
|
|
||||||
newy = Y // 8 * 8
|
|
||||||
H = H + (Y - newy)
|
|
||||||
Y = newy
|
|
||||||
if H % 8 != 0:
|
|
||||||
H = H + (8 - (H % 8))
|
|
||||||
if W % 8 != 0:
|
|
||||||
W = W + (8 - (W % 8))
|
|
||||||
area = (int(H), int(W), int(Y), int(X))
|
area = (int(H), int(W), int(Y), int(X))
|
||||||
modified['area'] = area
|
modified['area'] = area
|
||||||
|
|
||||||
@ -482,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
|||||||
|
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -519,6 +532,8 @@ class KSampler:
|
|||||||
|
|
||||||
if self.scheduler == "karras":
|
if self.scheduler == "karras":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
|
elif self.scheduler == "exponential":
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
elif self.scheduler == "normal":
|
elif self.scheduler == "normal":
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
elif self.scheduler == "simple":
|
elif self.scheduler == "simple":
|
||||||
|
|||||||
66
comfy/sd.py
66
comfy/sd.py
@ -14,6 +14,7 @@ from .t2i_adapter import adapter
|
|||||||
from . import utils
|
from . import utils
|
||||||
from . import clip_vision
|
from . import clip_vision
|
||||||
from . import gligen
|
from . import gligen
|
||||||
|
from . import diffusers_convert
|
||||||
|
|
||||||
def load_torch_file(ckpt):
|
def load_torch_file(ckpt):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
@ -324,15 +325,29 @@ def model_lora_keys(model, key_map={}):
|
|||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model):
|
def __init__(self, model, size=0):
|
||||||
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
self.patches = []
|
self.patches = []
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
self.model_options = {"transformer_options":{}}
|
self.model_options = {"transformer_options":{}}
|
||||||
|
self.model_size()
|
||||||
|
|
||||||
|
def model_size(self):
|
||||||
|
if self.size > 0:
|
||||||
|
return self.size
|
||||||
|
model_sd = self.model.state_dict()
|
||||||
|
size = 0
|
||||||
|
for k in model_sd:
|
||||||
|
t = model_sd[k]
|
||||||
|
size += t.nelement() * t.element_size()
|
||||||
|
self.size = size
|
||||||
|
return size
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model)
|
n = ModelPatcher(self.model, self.size)
|
||||||
n.patches = self.patches[:]
|
n.patches = self.patches[:]
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
return n
|
return n
|
||||||
@ -553,10 +568,16 @@ class VAE:
|
|||||||
if config is None:
|
if config is None:
|
||||||
#default SD1.x/SD2.x VAE parameters
|
#default SD1.x/SD2.x VAE parameters
|
||||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
|
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
|
||||||
else:
|
else:
|
||||||
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
|
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||||
self.first_stage_model = self.first_stage_model.eval()
|
self.first_stage_model = self.first_stage_model.eval()
|
||||||
|
if ckpt_path is not None:
|
||||||
|
sd = utils.load_torch_file(ckpt_path)
|
||||||
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
self.first_stage_model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
@ -630,12 +651,9 @@ class VAE:
|
|||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
|
||||||
target_batch_size = target_latent_tensor.shape[0]
|
|
||||||
|
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
print(current_batch_size, target_batch_size)
|
#print(current_batch_size, target_batch_size)
|
||||||
if current_batch_size == 1:
|
if current_batch_size == 1:
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
@ -652,7 +670,7 @@ def resize_image_to(tensor, target_latent_tensor, batched_number):
|
|||||||
return torch.cat([tensor] * batched_number, dim=0)
|
return torch.cat([tensor] * batched_number, dim=0)
|
||||||
|
|
||||||
class ControlNet:
|
class ControlNet:
|
||||||
def __init__(self, control_model, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
@ -661,6 +679,7 @@ class ControlNet:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
self.global_average_pooling = global_average_pooling
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -672,7 +691,9 @@ class ControlNet:
|
|||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
if self.control_model.dtype == torch.float16:
|
if self.control_model.dtype == torch.float16:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
@ -694,6 +715,9 @@ class ControlNet:
|
|||||||
key = 'output'
|
key = 'output'
|
||||||
index = i
|
index = i
|
||||||
x = control[i]
|
x = control[i]
|
||||||
|
if self.global_average_pooling:
|
||||||
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
@ -724,7 +748,7 @@ class ControlNet:
|
|||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
c.cond_hint_original = self.cond_hint_original
|
c.cond_hint_original = self.cond_hint_original
|
||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
return c
|
return c
|
||||||
@ -772,7 +796,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
use_spatial_transformer=True,
|
use_spatial_transformer=True,
|
||||||
transformer_depth=1,
|
transformer_depth=1,
|
||||||
context_dim=context_dim,
|
context_dim=context_dim,
|
||||||
use_checkpoint=True,
|
use_checkpoint=False,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
use_fp16=use_fp16)
|
use_fp16=use_fp16)
|
||||||
else:
|
else:
|
||||||
@ -789,7 +813,7 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
use_linear_in_transformer=True,
|
use_linear_in_transformer=True,
|
||||||
transformer_depth=1,
|
transformer_depth=1,
|
||||||
context_dim=context_dim,
|
context_dim=context_dim,
|
||||||
use_checkpoint=True,
|
use_checkpoint=False,
|
||||||
legacy=False,
|
legacy=False,
|
||||||
use_fp16=use_fp16)
|
use_fp16=use_fp16)
|
||||||
if pth:
|
if pth:
|
||||||
@ -819,7 +843,11 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if use_fp16:
|
if use_fp16:
|
||||||
control_model = control_model.half()
|
control_model = control_model.half()
|
||||||
|
|
||||||
control = ControlNet(control_model)
|
global_average_pooling = False
|
||||||
|
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||||
|
global_average_pooling = True
|
||||||
|
|
||||||
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
class T2IAdapter:
|
class T2IAdapter:
|
||||||
@ -843,10 +871,14 @@ class T2IAdapter:
|
|||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
|
self.control_input = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
if self.control_input is None:
|
||||||
self.t2i_model.to(self.device)
|
self.t2i_model.to(self.device)
|
||||||
self.control_input = self.t2i_model(self.cond_hint)
|
self.control_input = self.t2i_model(self.cond_hint)
|
||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
@ -1070,7 +1102,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
}
|
}
|
||||||
|
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"use_checkpoint": True,
|
"use_checkpoint": False,
|
||||||
"image_size": 32,
|
"image_size": 32,
|
||||||
"out_channels": 4,
|
"out_channels": 4,
|
||||||
"attention_resolutions": [
|
"attention_resolutions": [
|
||||||
|
|||||||
@ -56,7 +56,12 @@ class Downsample(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
assert x.shape[1] == self.channels
|
assert x.shape[1] == self.channels
|
||||||
return self.op(x)
|
if not self.use_conv:
|
||||||
|
padding = [x.shape[2] % 2, x.shape[3] % 2]
|
||||||
|
self.op.padding = padding
|
||||||
|
|
||||||
|
x = self.op(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import struct
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||||
else:
|
else:
|
||||||
|
if safe_load:
|
||||||
|
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||||
|
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||||
|
safe_load = False
|
||||||
if safe_load:
|
if safe_load:
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||||
else:
|
else:
|
||||||
@ -46,6 +51,88 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
|||||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||||
|
with open(safetensors_path, "rb") as f:
|
||||||
|
header = f.read(8)
|
||||||
|
length_of_header = struct.unpack('<Q', header)[0]
|
||||||
|
if length_of_header > max_size:
|
||||||
|
return None
|
||||||
|
return f.read(length_of_header)
|
||||||
|
|
||||||
|
def bislerp(samples, width, height):
|
||||||
|
def slerp(b1, b2, r):
|
||||||
|
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||||
|
|
||||||
|
c = b1.shape[-1]
|
||||||
|
|
||||||
|
#norms
|
||||||
|
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||||
|
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
#normalize
|
||||||
|
b1_normalized = b1 / b1_norms
|
||||||
|
b2_normalized = b2 / b2_norms
|
||||||
|
|
||||||
|
#zero when norms are zero
|
||||||
|
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
||||||
|
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
||||||
|
|
||||||
|
#slerp
|
||||||
|
dot = (b1_normalized*b2_normalized).sum(1)
|
||||||
|
omega = torch.acos(dot)
|
||||||
|
so = torch.sin(omega)
|
||||||
|
|
||||||
|
#technically not mathematically correct, but more pleasing?
|
||||||
|
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
||||||
|
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||||
|
|
||||||
|
#edge cases for same or polar opposites
|
||||||
|
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||||
|
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||||
|
return res
|
||||||
|
|
||||||
|
def generate_bilinear_data(length_old, length_new):
|
||||||
|
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
|
||||||
|
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||||
|
ratios = coords_1 - coords_1.floor()
|
||||||
|
coords_1 = coords_1.to(torch.int64)
|
||||||
|
|
||||||
|
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
|
||||||
|
coords_2[:,:,:,-1] -= 1
|
||||||
|
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||||
|
coords_2 = coords_2.to(torch.int64)
|
||||||
|
return ratios, coords_1, coords_2
|
||||||
|
|
||||||
|
n,c,h,w = samples.shape
|
||||||
|
h_new, w_new = (height, width)
|
||||||
|
|
||||||
|
#linear w
|
||||||
|
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
|
||||||
|
coords_1 = coords_1.expand((n, c, h, -1))
|
||||||
|
coords_2 = coords_2.expand((n, c, h, -1))
|
||||||
|
ratios = ratios.expand((n, 1, h, -1))
|
||||||
|
|
||||||
|
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
|
||||||
|
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
|
||||||
|
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||||
|
|
||||||
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
|
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||||
|
|
||||||
|
#linear h
|
||||||
|
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
|
||||||
|
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||||
|
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||||
|
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
||||||
|
|
||||||
|
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
|
||||||
|
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
|
||||||
|
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||||
|
|
||||||
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||||
|
return result
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
old_width = samples.shape[3]
|
old_width = samples.shape[3]
|
||||||
@ -61,7 +148,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||||
else:
|
else:
|
||||||
s = samples
|
s = samples
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
|
||||||
|
if upscale_method == "bislerp":
|
||||||
|
return bislerp(s, width, height)
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||||
|
|||||||
@ -0,0 +1,110 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class CA_layer(nn.Module):
|
||||||
|
def __init__(self, channel, reduction=16):
|
||||||
|
super(CA_layer, self).__init__()
|
||||||
|
# global average pooling
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
|
||||||
|
# nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.fc(self.gap(x))
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Simple_CA_layer(nn.Module):
|
||||||
|
def __init__(self, channel):
|
||||||
|
super(Simple_CA_layer, self).__init__()
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Conv2d(
|
||||||
|
in_channels=channel,
|
||||||
|
out_channels=channel,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.fc(self.gap(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ECA_layer(nn.Module):
|
||||||
|
"""Constructs a ECA module.
|
||||||
|
Args:
|
||||||
|
channel: Number of channels of the input feature map
|
||||||
|
k_size: Adaptive selection of kernel size
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
super(ECA_layer, self).__init__()
|
||||||
|
|
||||||
|
b = 1
|
||||||
|
gamma = 2
|
||||||
|
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||||
|
k_size = k_size if k_size % 2 else k_size + 1
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||||
|
)
|
||||||
|
# self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x: input features with shape [b, c, h, w]
|
||||||
|
# b, c, h, w = x.size()
|
||||||
|
|
||||||
|
# feature descriptor on the global spatial information
|
||||||
|
y = self.avg_pool(x)
|
||||||
|
|
||||||
|
# Two different branches of ECA module
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Multi-scale information fusion
|
||||||
|
# y = self.sigmoid(y)
|
||||||
|
|
||||||
|
return x * y.expand_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ECA_MaxPool_layer(nn.Module):
|
||||||
|
"""Constructs a ECA module.
|
||||||
|
Args:
|
||||||
|
channel: Number of channels of the input feature map
|
||||||
|
k_size: Adaptive selection of kernel size
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
super(ECA_MaxPool_layer, self).__init__()
|
||||||
|
|
||||||
|
b = 1
|
||||||
|
gamma = 2
|
||||||
|
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||||
|
k_size = k_size if k_size % 2 else k_size + 1
|
||||||
|
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||||
|
)
|
||||||
|
# self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x: input features with shape [b, c, h, w]
|
||||||
|
# b, c, h, w = x.size()
|
||||||
|
|
||||||
|
# feature descriptor on the global spatial information
|
||||||
|
y = self.max_pool(x)
|
||||||
|
|
||||||
|
# Two different branches of ECA module
|
||||||
|
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Multi-scale information fusion
|
||||||
|
# y = self.sigmoid(y)
|
||||||
|
|
||||||
|
return x * y.expand_as(x)
|
||||||
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
@ -0,0 +1,577 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: OSA.py
|
||||||
|
# Created Date: Tuesday April 28th 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
from torch import einsum, nn
|
||||||
|
|
||||||
|
from .layernorm import LayerNorm2d
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(val, length=1):
|
||||||
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
|
|
||||||
|
# helper classes
|
||||||
|
|
||||||
|
|
||||||
|
class PreNormResidual(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fn(self.norm(x)) + x
|
||||||
|
|
||||||
|
|
||||||
|
class Conv_PreNormResidual(nn.Module):
|
||||||
|
def __init__(self, dim, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = LayerNorm2d(dim)
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fn(self.norm(x)) + x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, mult=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(dim, inner_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv_FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, mult=2, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Gated_Conv_FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hidden_features = int(dim * mult)
|
||||||
|
|
||||||
|
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
self.dwconv = nn.Conv2d(
|
||||||
|
hidden_features * 2,
|
||||||
|
hidden_features * 2,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
groups=hidden_features * 2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.project_in(x)
|
||||||
|
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||||
|
x = F.gelu(x1) * x2
|
||||||
|
x = self.project_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# MBConv
|
||||||
|
|
||||||
|
|
||||||
|
class SqueezeExcitation(nn.Module):
|
||||||
|
def __init__(self, dim, shrinkage_rate=0.25):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(dim * shrinkage_rate)
|
||||||
|
|
||||||
|
self.gate = nn.Sequential(
|
||||||
|
Reduce("b c h w -> b c", "mean"),
|
||||||
|
nn.Linear(dim, hidden_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(hidden_dim, dim, bias=False),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
Rearrange("b c -> b c 1 1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.gate(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MBConvResidual(nn.Module):
|
||||||
|
def __init__(self, fn, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
self.dropsample = Dropsample(dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.fn(x)
|
||||||
|
out = self.dropsample(out)
|
||||||
|
return out + x
|
||||||
|
|
||||||
|
|
||||||
|
class Dropsample(nn.Module):
|
||||||
|
def __init__(self, prob=0):
|
||||||
|
super().__init__()
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if self.prob == 0.0 or (not self.training):
|
||||||
|
return x
|
||||||
|
|
||||||
|
keep_mask = (
|
||||||
|
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
||||||
|
> self.prob
|
||||||
|
)
|
||||||
|
return x * keep_mask / (1 - self.prob)
|
||||||
|
|
||||||
|
|
||||||
|
def MBConv(
|
||||||
|
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
||||||
|
):
|
||||||
|
hidden_dim = int(expansion_rate * dim_out)
|
||||||
|
stride = 2 if downsample else 1
|
||||||
|
|
||||||
|
net = nn.Sequential(
|
||||||
|
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||||
|
# nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(
|
||||||
|
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
||||||
|
),
|
||||||
|
# nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
||||||
|
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||||
|
# nn.BatchNorm2d(dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
if dim_in == dim_out and not downsample:
|
||||||
|
net = MBConvResidual(net, dropout=dropout)
|
||||||
|
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
# attention related classes
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_head=32,
|
||||||
|
dropout=0.0,
|
||||||
|
window_size=7,
|
||||||
|
with_pe=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
dim % dim_head
|
||||||
|
) == 0, "dimension should be divisible by dimension per head"
|
||||||
|
|
||||||
|
self.heads = dim // dim_head
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.with_pe = with_pe
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||||
|
|
||||||
|
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
# relative positional bias
|
||||||
|
if self.with_pe:
|
||||||
|
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||||
|
|
||||||
|
pos = torch.arange(window_size)
|
||||||
|
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||||
|
grid = rearrange(grid, "c i j -> (i j) c")
|
||||||
|
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
||||||
|
grid, "j ... -> 1 j ..."
|
||||||
|
)
|
||||||
|
rel_pos += window_size - 1
|
||||||
|
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch, height, width, window_height, window_width, _, device, h = (
|
||||||
|
*x.shape,
|
||||||
|
x.device,
|
||||||
|
self.heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
|
||||||
|
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
||||||
|
|
||||||
|
# project for queries, keys, values
|
||||||
|
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# sim
|
||||||
|
|
||||||
|
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||||
|
|
||||||
|
# add positional bias
|
||||||
|
if self.with_pe:
|
||||||
|
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||||
|
sim = sim + rearrange(bias, "i j h -> h i j")
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
attn = self.attend(sim)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
|
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
|
out = rearrange(
|
||||||
|
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
||||||
|
)
|
||||||
|
|
||||||
|
# combine heads out
|
||||||
|
|
||||||
|
out = self.to_out(out)
|
||||||
|
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
||||||
|
|
||||||
|
|
||||||
|
class Block_Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_head=32,
|
||||||
|
bias=False,
|
||||||
|
dropout=0.0,
|
||||||
|
window_size=7,
|
||||||
|
with_pe=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert (
|
||||||
|
dim % dim_head
|
||||||
|
) == 0, "dimension should be divisible by dimension per head"
|
||||||
|
|
||||||
|
self.heads = dim // dim_head
|
||||||
|
self.ps = window_size
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.with_pe = with_pe
|
||||||
|
|
||||||
|
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||||
|
self.qkv_dwconv = nn.Conv2d(
|
||||||
|
dim * 3,
|
||||||
|
dim * 3,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
groups=dim * 3,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||||
|
|
||||||
|
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# project for queries, keys, values
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
qkv = self.qkv_dwconv(self.qkv(x))
|
||||||
|
q, k, v = qkv.chunk(3, dim=1)
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: rearrange(
|
||||||
|
t,
|
||||||
|
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
||||||
|
h=self.heads,
|
||||||
|
w1=self.ps,
|
||||||
|
w2=self.ps,
|
||||||
|
),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# sim
|
||||||
|
|
||||||
|
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
attn = self.attend(sim)
|
||||||
|
|
||||||
|
# aggregate
|
||||||
|
|
||||||
|
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
out = rearrange(
|
||||||
|
out,
|
||||||
|
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
||||||
|
x=h // self.ps,
|
||||||
|
y=w // self.ps,
|
||||||
|
head=self.heads,
|
||||||
|
w1=self.ps,
|
||||||
|
w2=self.ps,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Channel_Attention(nn.Module):
|
||||||
|
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||||
|
super(Channel_Attention, self).__init__()
|
||||||
|
self.heads = heads
|
||||||
|
|
||||||
|
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||||
|
|
||||||
|
self.ps = window_size
|
||||||
|
|
||||||
|
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||||
|
self.qkv_dwconv = nn.Conv2d(
|
||||||
|
dim * 3,
|
||||||
|
dim * 3,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
groups=dim * 3,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
qkv = self.qkv_dwconv(self.qkv(x))
|
||||||
|
qkv = qkv.chunk(3, dim=1)
|
||||||
|
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: rearrange(
|
||||||
|
t,
|
||||||
|
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
||||||
|
ph=self.ps,
|
||||||
|
pw=self.ps,
|
||||||
|
head=self.heads,
|
||||||
|
),
|
||||||
|
qkv,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = F.normalize(q, dim=-1)
|
||||||
|
k = F.normalize(k, dim=-1)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
out = attn @ v
|
||||||
|
|
||||||
|
out = rearrange(
|
||||||
|
out,
|
||||||
|
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
||||||
|
h=h // self.ps,
|
||||||
|
w=w // self.ps,
|
||||||
|
ph=self.ps,
|
||||||
|
pw=self.ps,
|
||||||
|
head=self.heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.project_out(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Channel_Attention_grid(nn.Module):
|
||||||
|
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||||
|
super(Channel_Attention_grid, self).__init__()
|
||||||
|
self.heads = heads
|
||||||
|
|
||||||
|
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||||
|
|
||||||
|
self.ps = window_size
|
||||||
|
|
||||||
|
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||||
|
self.qkv_dwconv = nn.Conv2d(
|
||||||
|
dim * 3,
|
||||||
|
dim * 3,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
groups=dim * 3,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
qkv = self.qkv_dwconv(self.qkv(x))
|
||||||
|
qkv = qkv.chunk(3, dim=1)
|
||||||
|
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: rearrange(
|
||||||
|
t,
|
||||||
|
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
||||||
|
ph=self.ps,
|
||||||
|
pw=self.ps,
|
||||||
|
head=self.heads,
|
||||||
|
),
|
||||||
|
qkv,
|
||||||
|
)
|
||||||
|
|
||||||
|
q = F.normalize(q, dim=-1)
|
||||||
|
k = F.normalize(k, dim=-1)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
out = attn @ v
|
||||||
|
|
||||||
|
out = rearrange(
|
||||||
|
out,
|
||||||
|
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
||||||
|
h=h // self.ps,
|
||||||
|
w=w // self.ps,
|
||||||
|
ph=self.ps,
|
||||||
|
pw=self.ps,
|
||||||
|
head=self.heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.project_out(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class OSA_Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channel_num=64,
|
||||||
|
bias=True,
|
||||||
|
ffn_bias=True,
|
||||||
|
window_size=8,
|
||||||
|
with_pe=False,
|
||||||
|
dropout=0.0,
|
||||||
|
):
|
||||||
|
super(OSA_Block, self).__init__()
|
||||||
|
|
||||||
|
w = window_size
|
||||||
|
|
||||||
|
self.layer = nn.Sequential(
|
||||||
|
MBConv(
|
||||||
|
channel_num,
|
||||||
|
channel_num,
|
||||||
|
downsample=False,
|
||||||
|
expansion_rate=1,
|
||||||
|
shrinkage_rate=0.25,
|
||||||
|
),
|
||||||
|
Rearrange(
|
||||||
|
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
||||||
|
), # block-like attention
|
||||||
|
PreNormResidual(
|
||||||
|
channel_num,
|
||||||
|
Attention(
|
||||||
|
dim=channel_num,
|
||||||
|
dim_head=channel_num // 4,
|
||||||
|
dropout=dropout,
|
||||||
|
window_size=window_size,
|
||||||
|
with_pe=with_pe,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||||
|
),
|
||||||
|
# channel-like attention
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num,
|
||||||
|
Channel_Attention(
|
||||||
|
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||||
|
),
|
||||||
|
Rearrange(
|
||||||
|
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
||||||
|
), # grid-like attention
|
||||||
|
PreNormResidual(
|
||||||
|
channel_num,
|
||||||
|
Attention(
|
||||||
|
dim=channel_num,
|
||||||
|
dim_head=channel_num // 4,
|
||||||
|
dropout=dropout,
|
||||||
|
window_size=window_size,
|
||||||
|
with_pe=with_pe,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||||
|
),
|
||||||
|
# channel-like attention
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num,
|
||||||
|
Channel_Attention_grid(
|
||||||
|
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Conv_PreNormResidual(
|
||||||
|
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.layer(x)
|
||||||
|
return out
|
||||||
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: OSAG.py
|
||||||
|
# Created Date: Tuesday April 28th 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .esa import ESA
|
||||||
|
from .OSA import OSA_Block
|
||||||
|
|
||||||
|
|
||||||
|
class OSAG(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channel_num=64,
|
||||||
|
bias=True,
|
||||||
|
block_num=4,
|
||||||
|
ffn_bias=False,
|
||||||
|
window_size=0,
|
||||||
|
pe=False,
|
||||||
|
):
|
||||||
|
super(OSAG, self).__init__()
|
||||||
|
|
||||||
|
# print("window_size: %d" % (window_size))
|
||||||
|
# print("with_pe", pe)
|
||||||
|
# print("ffn_bias: %d" % (ffn_bias))
|
||||||
|
|
||||||
|
# block_script_name = kwargs.get("block_script_name", "OSA")
|
||||||
|
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
|
||||||
|
|
||||||
|
# script_name = "." + block_script_name
|
||||||
|
# package = __import__(script_name, fromlist=True)
|
||||||
|
block_class = OSA_Block # getattr(package, block_class_name)
|
||||||
|
group_list = []
|
||||||
|
for _ in range(block_num):
|
||||||
|
temp_res = block_class(
|
||||||
|
channel_num,
|
||||||
|
bias,
|
||||||
|
ffn_bias=ffn_bias,
|
||||||
|
window_size=window_size,
|
||||||
|
with_pe=pe,
|
||||||
|
)
|
||||||
|
group_list.append(temp_res)
|
||||||
|
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
|
||||||
|
self.residual_layer = nn.Sequential(*group_list)
|
||||||
|
esa_channel = max(channel_num // 4, 16)
|
||||||
|
self.esa = ESA(esa_channel, channel_num)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.residual_layer(x)
|
||||||
|
out = out + x
|
||||||
|
return self.esa(out)
|
||||||
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: OmniSR.py
|
||||||
|
# Created Date: Tuesday April 28th 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .OSAG import OSAG
|
||||||
|
from .pixelshuffle import pixelshuffle_block
|
||||||
|
|
||||||
|
|
||||||
|
class OmniSR(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_dict,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super(OmniSR, self).__init__()
|
||||||
|
self.state = state_dict
|
||||||
|
|
||||||
|
bias = True # Fine to assume this for now
|
||||||
|
block_num = 1 # Fine to assume this for now
|
||||||
|
ffn_bias = True
|
||||||
|
pe = True
|
||||||
|
|
||||||
|
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||||
|
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||||
|
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||||
|
|
||||||
|
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||||
|
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||||
|
if up_scale - int(up_scale) > 0:
|
||||||
|
print(
|
||||||
|
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||||
|
)
|
||||||
|
up_scale = int(up_scale)
|
||||||
|
res_num = 0
|
||||||
|
for key in state_dict.keys():
|
||||||
|
if "residual_layer" in key:
|
||||||
|
temp_res_num = int(key.split(".")[1])
|
||||||
|
if temp_res_num > res_num:
|
||||||
|
res_num = temp_res_num
|
||||||
|
res_num = res_num + 1 # zero-indexed
|
||||||
|
|
||||||
|
residual_layer = []
|
||||||
|
self.res_num = res_num
|
||||||
|
|
||||||
|
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
|
||||||
|
self.up_scale = up_scale
|
||||||
|
|
||||||
|
for _ in range(res_num):
|
||||||
|
temp_res = OSAG(
|
||||||
|
channel_num=num_feat,
|
||||||
|
bias=bias,
|
||||||
|
block_num=block_num,
|
||||||
|
ffn_bias=ffn_bias,
|
||||||
|
window_size=self.window_size,
|
||||||
|
pe=pe,
|
||||||
|
)
|
||||||
|
residual_layer.append(temp_res)
|
||||||
|
self.residual_layer = nn.Sequential(*residual_layer)
|
||||||
|
self.input = nn.Conv2d(
|
||||||
|
in_channels=num_in_ch,
|
||||||
|
out_channels=num_feat,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.output = nn.Conv2d(
|
||||||
|
in_channels=num_feat,
|
||||||
|
out_channels=num_feat,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||||
|
|
||||||
|
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||||
|
|
||||||
|
# for m in self.modules():
|
||||||
|
# if isinstance(m, nn.Conv2d):
|
||||||
|
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||||
|
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||||
|
|
||||||
|
# chaiNNer specific stuff
|
||||||
|
self.model_arch = "OmniSR"
|
||||||
|
self.sub_type = "SR"
|
||||||
|
self.in_nc = num_in_ch
|
||||||
|
self.out_nc = num_out_ch
|
||||||
|
self.num_feat = num_feat
|
||||||
|
self.scale = up_scale
|
||||||
|
|
||||||
|
self.supports_fp16 = True # TODO: Test this
|
||||||
|
self.supports_bfp16 = True
|
||||||
|
self.min_size_restriction = 16
|
||||||
|
|
||||||
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
def check_image_size(self, x):
|
||||||
|
_, _, h, w = x.size()
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
|
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||||
|
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||||
|
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||||
|
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
H, W = x.shape[2:]
|
||||||
|
x = self.check_image_size(x)
|
||||||
|
|
||||||
|
residual = self.input(x)
|
||||||
|
out = self.residual_layer(residual)
|
||||||
|
|
||||||
|
# origin
|
||||||
|
out = torch.add(self.output(out), residual)
|
||||||
|
out = self.up(out)
|
||||||
|
|
||||||
|
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||||
|
return out
|
||||||
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
@ -0,0 +1,294 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: esa.py
|
||||||
|
# Created Date: Tuesday April 28th 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .layernorm import LayerNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
def moment(x, dim=(2, 3), k=2):
|
||||||
|
assert len(x.size()) == 4
|
||||||
|
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
||||||
|
return mk
|
||||||
|
|
||||||
|
|
||||||
|
class ESA(nn.Module):
|
||||||
|
"""
|
||||||
|
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
||||||
|
`Residual Feature Aggregation Network for Image Super-Resolution`
|
||||||
|
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
||||||
|
are deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
||||||
|
super(ESA, self).__init__()
|
||||||
|
f = esa_channels
|
||||||
|
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||||
|
self.conv_f = conv(f, f, kernel_size=1)
|
||||||
|
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
||||||
|
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
||||||
|
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
c1_ = self.conv1(x)
|
||||||
|
c1 = self.conv2(c1_)
|
||||||
|
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
||||||
|
c3 = self.conv3(v_max)
|
||||||
|
c3 = F.interpolate(
|
||||||
|
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
||||||
|
)
|
||||||
|
cf = self.conv_f(c1_)
|
||||||
|
c4 = self.conv4(c3 + cf)
|
||||||
|
m = self.sigmoid(c4)
|
||||||
|
return x * m
|
||||||
|
|
||||||
|
|
||||||
|
class LK_ESA(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||||
|
):
|
||||||
|
super(LK_ESA, self).__init__()
|
||||||
|
f = esa_channels
|
||||||
|
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||||
|
self.conv_f = conv(f, f, kernel_size=1)
|
||||||
|
|
||||||
|
kernel_size = 17
|
||||||
|
kernel_expand = kernel_expand
|
||||||
|
padding = kernel_size // 2
|
||||||
|
|
||||||
|
self.vec_conv = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(1, kernel_size),
|
||||||
|
padding=(0, padding),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.vec_conv3x1 = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(1, 3),
|
||||||
|
padding=(0, 1),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hor_conv = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(kernel_size, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.hor_conv1x3 = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(3, 1),
|
||||||
|
padding=(1, 0),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
c1_ = self.conv1(x)
|
||||||
|
|
||||||
|
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||||
|
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||||
|
|
||||||
|
cf = self.conv_f(c1_)
|
||||||
|
c4 = self.conv4(res + cf)
|
||||||
|
m = self.sigmoid(c4)
|
||||||
|
return x * m
|
||||||
|
|
||||||
|
|
||||||
|
class LK_ESA_LN(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||||
|
):
|
||||||
|
super(LK_ESA_LN, self).__init__()
|
||||||
|
f = esa_channels
|
||||||
|
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||||
|
self.conv_f = conv(f, f, kernel_size=1)
|
||||||
|
|
||||||
|
kernel_size = 17
|
||||||
|
kernel_expand = kernel_expand
|
||||||
|
padding = kernel_size // 2
|
||||||
|
|
||||||
|
self.norm = LayerNorm2d(n_feats)
|
||||||
|
|
||||||
|
self.vec_conv = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(1, kernel_size),
|
||||||
|
padding=(0, padding),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.vec_conv3x1 = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(1, 3),
|
||||||
|
padding=(0, 1),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hor_conv = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(kernel_size, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.hor_conv1x3 = nn.Conv2d(
|
||||||
|
in_channels=f * kernel_expand,
|
||||||
|
out_channels=f * kernel_expand,
|
||||||
|
kernel_size=(3, 1),
|
||||||
|
padding=(1, 0),
|
||||||
|
groups=2,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
c1_ = self.norm(x)
|
||||||
|
c1_ = self.conv1(c1_)
|
||||||
|
|
||||||
|
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||||
|
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||||
|
|
||||||
|
cf = self.conv_f(c1_)
|
||||||
|
c4 = self.conv4(res + cf)
|
||||||
|
m = self.sigmoid(c4)
|
||||||
|
return x * m
|
||||||
|
|
||||||
|
|
||||||
|
class AdaGuidedFilter(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||||
|
):
|
||||||
|
super(AdaGuidedFilter, self).__init__()
|
||||||
|
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Conv2d(
|
||||||
|
in_channels=n_feats,
|
||||||
|
out_channels=1,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.r = 5
|
||||||
|
|
||||||
|
def box_filter(self, x, r):
|
||||||
|
channel = x.shape[1]
|
||||||
|
kernel_size = 2 * r + 1
|
||||||
|
weight = 1.0 / (kernel_size**2)
|
||||||
|
box_kernel = weight * torch.ones(
|
||||||
|
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
||||||
|
)
|
||||||
|
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, _, H, W = x.shape
|
||||||
|
N = self.box_filter(
|
||||||
|
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
||||||
|
)
|
||||||
|
|
||||||
|
# epsilon = self.fc(self.gap(x))
|
||||||
|
# epsilon = torch.pow(epsilon, 2)
|
||||||
|
epsilon = 1e-2
|
||||||
|
|
||||||
|
mean_x = self.box_filter(x, self.r) / N
|
||||||
|
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
||||||
|
|
||||||
|
A = var_x / (var_x + epsilon)
|
||||||
|
b = (1 - A) * mean_x
|
||||||
|
m = A * x + b
|
||||||
|
|
||||||
|
# mean_A = self.box_filter(A, self.r) / N
|
||||||
|
# mean_b = self.box_filter(b, self.r) / N
|
||||||
|
# m = mean_A * x + mean_b
|
||||||
|
return x * m
|
||||||
|
|
||||||
|
|
||||||
|
class AdaConvGuidedFilter(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||||
|
):
|
||||||
|
super(AdaConvGuidedFilter, self).__init__()
|
||||||
|
f = esa_channels
|
||||||
|
|
||||||
|
self.conv_f = conv(f, f, kernel_size=1)
|
||||||
|
|
||||||
|
kernel_size = 17
|
||||||
|
kernel_expand = kernel_expand
|
||||||
|
padding = kernel_size // 2
|
||||||
|
|
||||||
|
self.vec_conv = nn.Conv2d(
|
||||||
|
in_channels=f,
|
||||||
|
out_channels=f,
|
||||||
|
kernel_size=(1, kernel_size),
|
||||||
|
padding=(0, padding),
|
||||||
|
groups=f,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hor_conv = nn.Conv2d(
|
||||||
|
in_channels=f,
|
||||||
|
out_channels=f,
|
||||||
|
kernel_size=(kernel_size, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
groups=f,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.fc = nn.Conv2d(
|
||||||
|
in_channels=f,
|
||||||
|
out_channels=f,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.vec_conv(x)
|
||||||
|
y = self.hor_conv(y)
|
||||||
|
|
||||||
|
sigma = torch.pow(y, 2)
|
||||||
|
epsilon = self.fc(self.gap(y))
|
||||||
|
|
||||||
|
weight = sigma / (sigma + epsilon)
|
||||||
|
|
||||||
|
m = weight * x + (1 - weight)
|
||||||
|
|
||||||
|
return x * m
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: layernorm.py
|
||||||
|
# Created Date: Tuesday April 28th 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNormFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, weight, bias, eps):
|
||||||
|
ctx.eps = eps
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
mu = x.mean(1, keepdim=True)
|
||||||
|
var = (x - mu).pow(2).mean(1, keepdim=True)
|
||||||
|
y = (x - mu) / (var + eps).sqrt()
|
||||||
|
ctx.save_for_backward(y, var, weight)
|
||||||
|
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
eps = ctx.eps
|
||||||
|
|
||||||
|
N, C, H, W = grad_output.size()
|
||||||
|
y, var, weight = ctx.saved_variables
|
||||||
|
g = grad_output * weight.view(1, C, 1, 1)
|
||||||
|
mean_g = g.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
||||||
|
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
||||||
|
return (
|
||||||
|
gx,
|
||||||
|
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
||||||
|
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm2d(nn.Module):
|
||||||
|
def __init__(self, channels, eps=1e-6):
|
||||||
|
super(LayerNorm2d, self).__init__()
|
||||||
|
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
||||||
|
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class GRN(nn.Module):
|
||||||
|
"""GRN (Global Response Normalization) layer"""
|
||||||
|
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
||||||
|
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
||||||
|
return self.gamma * (x * Nx) + self.beta + x
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#############################################################
|
||||||
|
# File: pixelshuffle.py
|
||||||
|
# Created Date: Friday July 1st 2022
|
||||||
|
# Author: Chen Xuanhong
|
||||||
|
# Email: chenxuanhongzju@outlook.com
|
||||||
|
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
||||||
|
# Modified By: Chen Xuanhong
|
||||||
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def pixelshuffle_block(
|
||||||
|
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Upsample features according to `upscale_factor`.
|
||||||
|
"""
|
||||||
|
padding = kernel_size // 2
|
||||||
|
conv = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels * (upscale_factor**2),
|
||||||
|
kernel_size,
|
||||||
|
padding=1,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||||
|
return nn.Sequential(*[conv, pixel_shuffle])
|
||||||
@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
|
|||||||
self.scale: int = self.get_scale()
|
self.scale: int = self.get_scale()
|
||||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
||||||
|
|
||||||
|
c2x2 = False
|
||||||
|
if self.state["model.0.weight"].shape[-2] == 2:
|
||||||
|
c2x2 = True
|
||||||
|
self.scale = round(math.sqrt(self.scale / 4))
|
||||||
|
self.model_arch = "ESRGAN-2c2"
|
||||||
|
|
||||||
self.supports_fp16 = True
|
self.supports_fp16 = True
|
||||||
self.supports_bfp16 = True
|
self.supports_bfp16 = True
|
||||||
self.min_size_restriction = None
|
self.min_size_restriction = None
|
||||||
@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
|
|||||||
out_nc=self.num_filters,
|
out_nc=self.num_filters,
|
||||||
upscale_factor=3,
|
upscale_factor=3,
|
||||||
act_type=self.act,
|
act_type=self.act,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
upsample_blocks = [
|
upsample_blocks = [
|
||||||
upsample_block(
|
upsample_block(
|
||||||
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
|
in_nc=self.num_filters,
|
||||||
|
out_nc=self.num_filters,
|
||||||
|
act_type=self.act,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
for _ in range(int(math.log(self.scale, 2)))
|
for _ in range(int(math.log(self.scale, 2)))
|
||||||
]
|
]
|
||||||
@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
norm_type=None,
|
norm_type=None,
|
||||||
act_type=None,
|
act_type=None,
|
||||||
|
c2x2=c2x2,
|
||||||
),
|
),
|
||||||
B.ShortcutBlock(
|
B.ShortcutBlock(
|
||||||
B.sequential(
|
B.sequential(
|
||||||
@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
|
|||||||
act_type=self.act,
|
act_type=self.act,
|
||||||
mode="CNA",
|
mode="CNA",
|
||||||
plus=self.plus,
|
plus=self.plus,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
for _ in range(self.num_blocks)
|
for _ in range(self.num_blocks)
|
||||||
],
|
],
|
||||||
@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
|
|||||||
norm_type=self.norm,
|
norm_type=self.norm,
|
||||||
act_type=None,
|
act_type=None,
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
|
c2x2=c2x2,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
norm_type=None,
|
norm_type=None,
|
||||||
act_type=self.act,
|
act_type=self.act,
|
||||||
|
c2x2=c2x2,
|
||||||
),
|
),
|
||||||
# hr_conv1
|
# hr_conv1
|
||||||
B.conv_block(
|
B.conv_block(
|
||||||
@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
norm_type=None,
|
norm_type=None,
|
||||||
act_type=None,
|
act_type=None,
|
||||||
|
c2x2=c2x2,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -141,6 +141,19 @@ def sequential(*args):
|
|||||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
||||||
|
|
||||||
|
|
||||||
|
# 2x2x2 Conv Block
|
||||||
|
def conv_block_2c2(
|
||||||
|
in_nc,
|
||||||
|
out_nc,
|
||||||
|
act_type="relu",
|
||||||
|
):
|
||||||
|
return sequential(
|
||||||
|
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
||||||
|
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
||||||
|
act(act_type) if act_type else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def conv_block(
|
def conv_block(
|
||||||
in_nc: int,
|
in_nc: int,
|
||||||
out_nc: int,
|
out_nc: int,
|
||||||
@ -153,12 +166,17 @@ def conv_block(
|
|||||||
norm_type: str | None = None,
|
norm_type: str | None = None,
|
||||||
act_type: str | None = "relu",
|
act_type: str | None = "relu",
|
||||||
mode: ConvMode = "CNA",
|
mode: ConvMode = "CNA",
|
||||||
|
c2x2=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Conv layer with padding, normalization, activation
|
Conv layer with padding, normalization, activation
|
||||||
mode: CNA --> Conv -> Norm -> Act
|
mode: CNA --> Conv -> Norm -> Act
|
||||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if c2x2:
|
||||||
|
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
||||||
|
|
||||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
||||||
padding = get_valid_padding(kernel_size, dilation)
|
padding = get_valid_padding(kernel_size, dilation)
|
||||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
||||||
@ -285,6 +303,7 @@ class RRDB(nn.Module):
|
|||||||
_convtype="Conv2D",
|
_convtype="Conv2D",
|
||||||
_spectral_norm=False,
|
_spectral_norm=False,
|
||||||
plus=False,
|
plus=False,
|
||||||
|
c2x2=False,
|
||||||
):
|
):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = ResidualDenseBlock_5C(
|
self.RDB1 = ResidualDenseBlock_5C(
|
||||||
@ -298,6 +317,7 @@ class RRDB(nn.Module):
|
|||||||
act_type,
|
act_type,
|
||||||
mode,
|
mode,
|
||||||
plus=plus,
|
plus=plus,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
self.RDB2 = ResidualDenseBlock_5C(
|
self.RDB2 = ResidualDenseBlock_5C(
|
||||||
nf,
|
nf,
|
||||||
@ -310,6 +330,7 @@ class RRDB(nn.Module):
|
|||||||
act_type,
|
act_type,
|
||||||
mode,
|
mode,
|
||||||
plus=plus,
|
plus=plus,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
self.RDB3 = ResidualDenseBlock_5C(
|
self.RDB3 = ResidualDenseBlock_5C(
|
||||||
nf,
|
nf,
|
||||||
@ -322,6 +343,7 @@ class RRDB(nn.Module):
|
|||||||
act_type,
|
act_type,
|
||||||
mode,
|
mode,
|
||||||
plus=plus,
|
plus=plus,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
act_type="leakyrelu",
|
act_type="leakyrelu",
|
||||||
mode: ConvMode = "CNA",
|
mode: ConvMode = "CNA",
|
||||||
plus=False,
|
plus=False,
|
||||||
|
c2x2=False,
|
||||||
):
|
):
|
||||||
super(ResidualDenseBlock_5C, self).__init__()
|
super(ResidualDenseBlock_5C, self).__init__()
|
||||||
|
|
||||||
@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=act_type,
|
act_type=act_type,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
self.conv2 = conv_block(
|
self.conv2 = conv_block(
|
||||||
nf + gc,
|
nf + gc,
|
||||||
@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=act_type,
|
act_type=act_type,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
self.conv3 = conv_block(
|
self.conv3 = conv_block(
|
||||||
nf + 2 * gc,
|
nf + 2 * gc,
|
||||||
@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=act_type,
|
act_type=act_type,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
self.conv4 = conv_block(
|
self.conv4 = conv_block(
|
||||||
nf + 3 * gc,
|
nf + 3 * gc,
|
||||||
@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=act_type,
|
act_type=act_type,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
if mode == "CNA":
|
if mode == "CNA":
|
||||||
last_act = None
|
last_act = None
|
||||||
@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=last_act,
|
act_type=last_act,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -499,6 +527,7 @@ def upconv_block(
|
|||||||
norm_type: str | None = None,
|
norm_type: str | None = None,
|
||||||
act_type="relu",
|
act_type="relu",
|
||||||
mode="nearest",
|
mode="nearest",
|
||||||
|
c2x2=False,
|
||||||
):
|
):
|
||||||
# Up conv
|
# Up conv
|
||||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
# described in https://distill.pub/2016/deconv-checkerboard/
|
||||||
@ -512,5 +541,6 @@ def upconv_block(
|
|||||||
pad_type=pad_type,
|
pad_type=pad_type,
|
||||||
norm_type=norm_type,
|
norm_type=norm_type,
|
||||||
act_type=act_type,
|
act_type=act_type,
|
||||||
|
c2x2=c2x2,
|
||||||
)
|
)
|
||||||
return sequential(upsample, conv)
|
return sequential(upsample, conv)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
|||||||
from .architecture.HAT import HAT
|
from .architecture.HAT import HAT
|
||||||
from .architecture.LaMa import LaMa
|
from .architecture.LaMa import LaMa
|
||||||
from .architecture.MAT import MAT
|
from .architecture.MAT import MAT
|
||||||
|
from .architecture.OmniSR.OmniSR import OmniSR
|
||||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||||
from .architecture.SPSR import SPSRNet as SPSR
|
from .architecture.SPSR import SPSRNet as SPSR
|
||||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||||
@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
|||||||
state_dict = state_dict["params"]
|
state_dict = state_dict["params"]
|
||||||
|
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
# SRVGGNet Real-ESRGAN (v2)
|
# SRVGGNet Real-ESRGAN (v2)
|
||||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
||||||
model = RealESRGANv2(state_dict)
|
model = RealESRGANv2(state_dict)
|
||||||
@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
|||||||
# MAT
|
# MAT
|
||||||
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
||||||
model = MAT(state_dict)
|
model = MAT(state_dict)
|
||||||
|
# Omni-SR
|
||||||
|
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
||||||
|
model = OmniSR(state_dict)
|
||||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
|||||||
from .architecture.HAT import HAT
|
from .architecture.HAT import HAT
|
||||||
from .architecture.LaMa import LaMa
|
from .architecture.LaMa import LaMa
|
||||||
from .architecture.MAT import MAT
|
from .architecture.MAT import MAT
|
||||||
|
from .architecture.OmniSR.OmniSR import OmniSR
|
||||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||||
from .architecture.SPSR import SPSRNet as SPSR
|
from .architecture.SPSR import SPSRNet as SPSR
|
||||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||||
@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
|||||||
from .architecture.Swin2SR import Swin2SR
|
from .architecture.Swin2SR import Swin2SR
|
||||||
from .architecture.SwinIR import SwinIR
|
from .architecture.SwinIR import SwinIR
|
||||||
|
|
||||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
|
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
|
||||||
PyTorchSRModel = Union[
|
PyTorchSRModel = Union[
|
||||||
RealESRGANv2,
|
RealESRGANv2,
|
||||||
SPSR,
|
SPSR,
|
||||||
@ -22,6 +23,7 @@ PyTorchSRModel = Union[
|
|||||||
SwinIR,
|
SwinIR,
|
||||||
Swin2SR,
|
Swin2SR,
|
||||||
HAT,
|
HAT,
|
||||||
|
OmniSR,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class MaskToImage:
|
|||||||
FUNCTION = "mask_to_image"
|
FUNCTION = "mask_to_image"
|
||||||
|
|
||||||
def mask_to_image(self, mask):
|
def mask_to_image(self, mask):
|
||||||
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
|
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
class ImageToMask:
|
class ImageToMask:
|
||||||
@ -167,7 +167,7 @@ class MaskComposite:
|
|||||||
"source": ("MASK",),
|
"source": ("MASK",),
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
"operation": (["multiply", "add", "subtract"],),
|
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,6 +193,12 @@ class MaskComposite:
|
|||||||
output[top:bottom, left:right] = destination_portion + source_portion
|
output[top:bottom, left:right] = destination_portion + source_portion
|
||||||
elif operation == "subtract":
|
elif operation == "subtract":
|
||||||
output[top:bottom, left:right] = destination_portion - source_portion
|
output[top:bottom, left:right] = destination_portion - source_portion
|
||||||
|
elif operation == "and":
|
||||||
|
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||||
|
elif operation == "or":
|
||||||
|
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||||
|
elif operation == "xor":
|
||||||
|
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||||
|
|
||||||
output = torch.clamp(output, 0.0, 1.0)
|
output = torch.clamp(output, 0.0, 1.0)
|
||||||
|
|
||||||
|
|||||||
@ -59,6 +59,12 @@ class Blend:
|
|||||||
def g(self, x):
|
def g(self, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
|
def gaussian_kernel(kernel_size: int, sigma: float):
|
||||||
|
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||||
|
d = torch.sqrt(x * x + y * y)
|
||||||
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||||
|
return g / g.sum()
|
||||||
|
|
||||||
class Blur:
|
class Blur:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@ -88,12 +94,6 @@ class Blur:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
def gaussian_kernel(self, kernel_size: int, sigma: float):
|
|
||||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
|
||||||
d = torch.sqrt(x * x + y * y)
|
|
||||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
|
||||||
return g / g.sum()
|
|
||||||
|
|
||||||
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
||||||
if blur_radius == 0:
|
if blur_radius == 0:
|
||||||
return (image,)
|
return (image,)
|
||||||
@ -101,10 +101,11 @@ class Blur:
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = blur_radius * 2 + 1
|
kernel_size = blur_radius * 2 + 1
|
||||||
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
|
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||||
|
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
|
||||||
blurred = blurred.permute(0, 2, 3, 1)
|
blurred = blurred.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
return (blurred,)
|
return (blurred,)
|
||||||
@ -167,9 +168,15 @@ class Sharpen:
|
|||||||
"max": 31,
|
"max": 31,
|
||||||
"step": 1
|
"step": 1
|
||||||
}),
|
}),
|
||||||
"alpha": ("FLOAT", {
|
"sigma": ("FLOAT", {
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
"min": 0.1,
|
"min": 0.1,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1
|
||||||
|
}),
|
||||||
|
"alpha": ("FLOAT", {
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.0,
|
||||||
"max": 5.0,
|
"max": 5.0,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
}),
|
}),
|
||||||
@ -181,21 +188,21 @@ class Sharpen:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
|
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
|
||||||
if sharpen_radius == 0:
|
if sharpen_radius == 0:
|
||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = sharpen_radius * 2 + 1
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
|
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
kernel[center, center] = kernel_size**2
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
kernel *= alpha
|
|
||||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
|
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||||
|
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
sharpened = sharpened.permute(0, 2, 3, 1)
|
sharpened = sharpened.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
result = torch.clamp(sharpened, 0, 1)
|
result = torch.clamp(sharpened, 0, 1)
|
||||||
|
|||||||
108
comfy_extras/nodes_rebatch.py
Normal file
108
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class LatentRebatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "latents": ("LATENT",),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
OUTPUT_IS_LIST = (True, )
|
||||||
|
|
||||||
|
FUNCTION = "rebatch"
|
||||||
|
|
||||||
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_batch(latents, list_ind, offset):
|
||||||
|
'''prepare a batch out of the list of latents'''
|
||||||
|
samples = latents[list_ind]['samples']
|
||||||
|
shape = samples.shape
|
||||||
|
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||||
|
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||||
|
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||||
|
if mask.shape[0] < samples.shape[0]:
|
||||||
|
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||||
|
if 'batch_index' in latents[list_ind]:
|
||||||
|
batch_inds = latents[list_ind]['batch_index']
|
||||||
|
else:
|
||||||
|
batch_inds = [x+offset for x in range(shape[0])]
|
||||||
|
return samples, mask, batch_inds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_slices(indexable, num, batch_size):
|
||||||
|
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||||
|
slices = []
|
||||||
|
for i in range(num):
|
||||||
|
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||||
|
if num * batch_size < len(indexable):
|
||||||
|
return slices, indexable[num * batch_size:]
|
||||||
|
else:
|
||||||
|
return slices, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def slice_batch(batch, num, batch_size):
|
||||||
|
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||||
|
return list(zip(*result))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cat_batch(batch1, batch2):
|
||||||
|
if batch1[0] is None:
|
||||||
|
return batch2
|
||||||
|
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def rebatch(self, latents, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
current_batch = (None, None, None)
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
for i in range(len(latents)):
|
||||||
|
# fetch new entry of list
|
||||||
|
#samples, masks, indices = self.get_batch(latents, i)
|
||||||
|
next_batch = self.get_batch(latents, i, processed)
|
||||||
|
processed += len(next_batch[2])
|
||||||
|
# set to current if current is None
|
||||||
|
if current_batch[0] is None:
|
||||||
|
current_batch = next_batch
|
||||||
|
# add previous to list if dimensions do not match
|
||||||
|
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||||
|
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
current_batch = next_batch
|
||||||
|
# cat if everything checks out
|
||||||
|
else:
|
||||||
|
current_batch = self.cat_batch(current_batch, next_batch)
|
||||||
|
|
||||||
|
# add to list if dimensions gone above target batch size
|
||||||
|
if current_batch[0].shape[0] > batch_size:
|
||||||
|
num = current_batch[0].shape[0] // batch_size
|
||||||
|
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||||
|
|
||||||
|
for i in range(num):
|
||||||
|
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||||
|
|
||||||
|
current_batch = remainder
|
||||||
|
|
||||||
|
#add remainder
|
||||||
|
if current_batch[0] is not None:
|
||||||
|
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
|
||||||
|
#get rid of empty masks
|
||||||
|
for s in output_list:
|
||||||
|
if s['noise_mask'].mean() == 1.0:
|
||||||
|
del s['noise_mask']
|
||||||
|
|
||||||
|
return (output_list,)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"RebatchLatents": LatentRebatch,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"RebatchLatents": "Rebatch Latents",
|
||||||
|
}
|
||||||
@ -17,7 +17,7 @@ class UpscaleModelLoader:
|
|||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
out = model_loading.load_state_dict(sd).eval()
|
out = model_loading.load_state_dict(sd).eval()
|
||||||
return (out, )
|
return (out, )
|
||||||
|
|
||||||
|
|||||||
559
execution.py
559
execution.py
@ -6,6 +6,7 @@ import threading
|
|||||||
import heapq
|
import heapq
|
||||||
import traceback
|
import traceback
|
||||||
import gc
|
import gc
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
@ -26,27 +27,96 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
else:
|
else:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||||
input_data_all[x] = input_data
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
for x in h:
|
for x in h:
|
||||||
if h[x] == "PROMPT":
|
if h[x] == "PROMPT":
|
||||||
input_data_all[x] = prompt
|
input_data_all[x] = [prompt]
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
if "extra_pnginfo" in extra_data:
|
if "extra_pnginfo" in extra_data:
|
||||||
input_data_all[x] = extra_data['extra_pnginfo']
|
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = unique_id
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|
||||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
|
# check if node wants the lists
|
||||||
|
intput_is_list = False
|
||||||
|
if hasattr(obj, "INPUT_IS_LIST"):
|
||||||
|
intput_is_list = obj.INPUT_IS_LIST
|
||||||
|
|
||||||
|
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||||
|
|
||||||
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
|
def slice_dict(d, i):
|
||||||
|
d_new = dict()
|
||||||
|
for k,v in d.items():
|
||||||
|
d_new[k] = v[i if len(v) > i else -1]
|
||||||
|
return d_new
|
||||||
|
|
||||||
|
results = []
|
||||||
|
if intput_is_list:
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**input_data_all))
|
||||||
|
else:
|
||||||
|
for i in range(max_len_input):
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all):
|
||||||
|
|
||||||
|
results = []
|
||||||
|
uis = []
|
||||||
|
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||||
|
|
||||||
|
for r in return_values:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
if 'ui' in r:
|
||||||
|
uis.append(r['ui'])
|
||||||
|
if 'result' in r:
|
||||||
|
results.append(r['result'])
|
||||||
|
else:
|
||||||
|
results.append(r)
|
||||||
|
|
||||||
|
output = []
|
||||||
|
if len(results) > 0:
|
||||||
|
# check which outputs need concatenating
|
||||||
|
output_is_list = [False] * len(results[0])
|
||||||
|
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||||
|
output_is_list = obj.OUTPUT_IS_LIST
|
||||||
|
|
||||||
|
# merge node execution results
|
||||||
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
|
if is_list:
|
||||||
|
output.append([x for o in results for x in o[i]])
|
||||||
|
else:
|
||||||
|
output.append([o[i] for o in results])
|
||||||
|
|
||||||
|
ui = dict()
|
||||||
|
if len(uis) > 0:
|
||||||
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
|
return output, ui
|
||||||
|
|
||||||
|
def format_value(x):
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
elif isinstance(x, (int, float, bool, str)):
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return str(x)
|
||||||
|
|
||||||
|
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if unique_id in outputs:
|
if unique_id in outputs:
|
||||||
return
|
return (True, None, None)
|
||||||
|
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
@ -55,23 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if input_unique_id not in outputs:
|
if input_unique_id not in outputs:
|
||||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
|
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||||
|
if result[0] is not True:
|
||||||
|
# Another node failed further upstream
|
||||||
|
return result
|
||||||
|
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
input_data_all = None
|
||||||
if server.client_id is not None:
|
try:
|
||||||
server.last_node_id = unique_id
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
|
||||||
obj = class_def()
|
|
||||||
|
|
||||||
nodes.before_node_execution()
|
|
||||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
|
||||||
if "ui" in outputs[unique_id]:
|
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
server.last_node_id = unique_id
|
||||||
if "result" in outputs[unique_id]:
|
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
outputs[unique_id] = outputs[unique_id]["result"]
|
obj = class_def()
|
||||||
|
|
||||||
|
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||||
|
outputs[unique_id] = output_data
|
||||||
|
if len(output_ui) > 0:
|
||||||
|
outputs_ui[unique_id] = output_ui
|
||||||
|
if server.client_id is not None:
|
||||||
|
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
|
print("Processing interrupted")
|
||||||
|
|
||||||
|
# skip formatting inputs/outputs
|
||||||
|
error_details = {
|
||||||
|
"node_id": unique_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
return (False, error_details, iex)
|
||||||
|
except Exception as ex:
|
||||||
|
typ, _, tb = sys.exc_info()
|
||||||
|
exception_type = full_type_name(typ)
|
||||||
|
input_data_formatted = {}
|
||||||
|
if input_data_all is not None:
|
||||||
|
input_data_formatted = {}
|
||||||
|
for name, inputs in input_data_all.items():
|
||||||
|
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||||
|
|
||||||
|
output_data_formatted = {}
|
||||||
|
for node_id, node_outputs in outputs.items():
|
||||||
|
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||||
|
|
||||||
|
print("!!! Exception during processing !!!")
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
error_details = {
|
||||||
|
"node_id": unique_id,
|
||||||
|
"exception_message": str(ex),
|
||||||
|
"exception_type": exception_type,
|
||||||
|
"traceback": traceback.format_tb(tb),
|
||||||
|
"current_inputs": input_data_formatted,
|
||||||
|
"current_outputs": output_data_formatted
|
||||||
|
}
|
||||||
|
return (False, error_details, ex)
|
||||||
|
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|
||||||
|
return (True, None, None)
|
||||||
|
|
||||||
def recursive_will_execute(prompt, outputs, current_item):
|
def recursive_will_execute(prompt, outputs, current_item):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
@ -105,7 +216,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||||
if input_data_all is not None:
|
if input_data_all is not None:
|
||||||
try:
|
try:
|
||||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
except:
|
except:
|
||||||
to_delete = True
|
to_delete = True
|
||||||
@ -144,10 +256,53 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.outputs = {}
|
self.outputs = {}
|
||||||
|
self.outputs_ui = {}
|
||||||
self.old_prompt = {}
|
self.old_prompt = {}
|
||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
def execute(self, prompt, extra_data={}):
|
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||||
|
node_id = error["node_id"]
|
||||||
|
class_type = prompt[node_id]["class_type"]
|
||||||
|
|
||||||
|
# First, send back the status to the frontend depending
|
||||||
|
# on the exception type
|
||||||
|
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||||
|
mes = {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"node_id": node_id,
|
||||||
|
"node_type": class_type,
|
||||||
|
"executed": list(executed),
|
||||||
|
}
|
||||||
|
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
||||||
|
else:
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
mes = {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"node_id": node_id,
|
||||||
|
"node_type": class_type,
|
||||||
|
"executed": list(executed),
|
||||||
|
|
||||||
|
"exception_message": error["exception_message"],
|
||||||
|
"exception_type": error["exception_type"],
|
||||||
|
"traceback": error["traceback"],
|
||||||
|
"current_inputs": error["current_inputs"],
|
||||||
|
"current_outputs": error["current_outputs"],
|
||||||
|
}
|
||||||
|
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||||
|
|
||||||
|
# Next, remove the subsequent outputs since they will not be executed
|
||||||
|
to_delete = []
|
||||||
|
for o in self.outputs:
|
||||||
|
if (o not in current_outputs) and (o not in executed):
|
||||||
|
to_delete += [o]
|
||||||
|
if o in self.old_prompt:
|
||||||
|
d = self.old_prompt.pop(o)
|
||||||
|
del d
|
||||||
|
for o in to_delete:
|
||||||
|
d = self.outputs.pop(o)
|
||||||
|
del d
|
||||||
|
|
||||||
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
@ -155,6 +310,10 @@ class PromptExecutor:
|
|||||||
else:
|
else:
|
||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
|
execution_start_time = time.perf_counter()
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
#delete cached outputs if nodes don't exist for them
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@ -169,105 +328,250 @@ class PromptExecutor:
|
|||||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||||
|
|
||||||
current_outputs = set(self.outputs.keys())
|
current_outputs = set(self.outputs.keys())
|
||||||
executed = set()
|
for x in list(self.outputs_ui.keys()):
|
||||||
try:
|
if x not in current_outputs:
|
||||||
to_execute = []
|
d = self.outputs_ui.pop(x)
|
||||||
for x in prompt:
|
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
to_execute += [(0, x)]
|
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
|
||||||
x = to_execute.pop(0)[-1]
|
|
||||||
|
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
if class_.OUTPUT_NODE == True:
|
|
||||||
valid = False
|
|
||||||
try:
|
|
||||||
m = validate_inputs(prompt, x)
|
|
||||||
valid = m[0]
|
|
||||||
except:
|
|
||||||
valid = False
|
|
||||||
if valid:
|
|
||||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
|
||||||
except Exception as e:
|
|
||||||
print(traceback.format_exc())
|
|
||||||
to_delete = []
|
|
||||||
for o in self.outputs:
|
|
||||||
if (o not in current_outputs) and (o not in executed):
|
|
||||||
to_delete += [o]
|
|
||||||
if o in self.old_prompt:
|
|
||||||
d = self.old_prompt.pop(o)
|
|
||||||
del d
|
|
||||||
for o in to_delete:
|
|
||||||
d = self.outputs.pop(o)
|
|
||||||
del d
|
del d
|
||||||
finally:
|
|
||||||
for x in executed:
|
|
||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
|
||||||
self.server.last_node_id = None
|
|
||||||
if self.server.client_id is not None:
|
|
||||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
|
||||||
|
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
executed = set()
|
||||||
|
output_node_id = None
|
||||||
|
to_execute = []
|
||||||
|
|
||||||
|
for node_id in list(execute_outputs):
|
||||||
|
to_execute += [(0, node_id)]
|
||||||
|
|
||||||
|
while len(to_execute) > 0:
|
||||||
|
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||||
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||||
|
output_node_id = to_execute.pop(0)[-1]
|
||||||
|
|
||||||
|
# This call shouldn't raise anything if there's an error deep in
|
||||||
|
# the actual SD code, instead it will report the node where the
|
||||||
|
# error was raised
|
||||||
|
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
|
||||||
|
if success is not True:
|
||||||
|
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||||
|
break
|
||||||
|
|
||||||
|
for x in executed:
|
||||||
|
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||||
|
self.server.last_node_id = None
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||||
|
|
||||||
|
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||||
gc.collect()
|
gc.collect()
|
||||||
comfy.model_management.soft_empty_cache()
|
comfy.model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item):
|
def validate_inputs(prompt, item, validated):
|
||||||
unique_id = item
|
unique_id = item
|
||||||
|
if unique_id in validated:
|
||||||
|
return validated[unique_id]
|
||||||
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
|
||||||
class_inputs = obj_class.INPUT_TYPES()
|
class_inputs = obj_class.INPUT_TYPES()
|
||||||
required_inputs = class_inputs['required']
|
required_inputs = class_inputs['required']
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
valid = True
|
||||||
|
|
||||||
for x in required_inputs:
|
for x in required_inputs:
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
return (False, "Required input is missing. {}, {}".format(class_type, x))
|
error = {
|
||||||
|
"type": "required_input_missing",
|
||||||
|
"message": "Required input is missing",
|
||||||
|
"details": f"{x}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
|
||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = required_inputs[x]
|
info = required_inputs[x]
|
||||||
type_input = info[0]
|
type_input = info[0]
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2:
|
||||||
return (False, "Bad Input. {}, {}".format(class_type, x))
|
error = {
|
||||||
|
"type": "bad_linked_input",
|
||||||
|
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||||
|
"details": f"{x}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_value": val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
|
||||||
o_id = val[0]
|
o_id = val[0]
|
||||||
o_class_type = prompt[o_id]['class_type']
|
o_class_type = prompt[o_id]['class_type']
|
||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
if r[val[1]] != type_input:
|
if r[val[1]] != type_input:
|
||||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
received_type = r[val[1]]
|
||||||
r = validate_inputs(prompt, o_id)
|
details = f"{x}, {received_type} != {type_input}"
|
||||||
if r[0] == False:
|
error = {
|
||||||
return r
|
"type": "return_type_mismatch",
|
||||||
|
"message": "Return type mismatch between linked nodes",
|
||||||
|
"details": details,
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_type": received_type,
|
||||||
|
"linked_node": val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
r = validate_inputs(prompt, o_id, validated)
|
||||||
|
if r[0] is False:
|
||||||
|
# `r` will be set in `validated[o_id]` already
|
||||||
|
valid = False
|
||||||
|
continue
|
||||||
|
except Exception as ex:
|
||||||
|
typ, _, tb = sys.exc_info()
|
||||||
|
valid = False
|
||||||
|
exception_type = full_type_name(typ)
|
||||||
|
reasons = [{
|
||||||
|
"type": "exception_during_inner_validation",
|
||||||
|
"message": "Exception when validating inner node",
|
||||||
|
"details": str(ex),
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"exception_message": str(ex),
|
||||||
|
"exception_type": exception_type,
|
||||||
|
"traceback": traceback.format_tb(tb),
|
||||||
|
"linked_node": val
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
validated[o_id] = (False, reasons, o_id)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
if type_input == "INT":
|
try:
|
||||||
val = int(val)
|
if type_input == "INT":
|
||||||
inputs[x] = val
|
val = int(val)
|
||||||
if type_input == "FLOAT":
|
inputs[x] = val
|
||||||
val = float(val)
|
if type_input == "FLOAT":
|
||||||
inputs[x] = val
|
val = float(val)
|
||||||
if type_input == "STRING":
|
inputs[x] = val
|
||||||
val = str(val)
|
if type_input == "STRING":
|
||||||
inputs[x] = val
|
val = str(val)
|
||||||
|
inputs[x] = val
|
||||||
|
except Exception as ex:
|
||||||
|
error = {
|
||||||
|
"type": "invalid_input_type",
|
||||||
|
"message": f"Failed to convert an input value to a {type_input} value",
|
||||||
|
"details": f"{x}, {val}, {ex}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_value": val,
|
||||||
|
"exception_message": str(ex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
|
||||||
if len(info) > 1:
|
if len(info) > 1:
|
||||||
if "min" in info[1] and val < info[1]["min"]:
|
if "min" in info[1] and val < info[1]["min"]:
|
||||||
return (False, "Value smaller than min. {}, {}".format(class_type, x))
|
error = {
|
||||||
|
"type": "value_smaller_than_min",
|
||||||
|
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||||
|
"details": f"{x}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_value": val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in info[1] and val > info[1]["max"]:
|
||||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
error = {
|
||||||
|
"type": "value_bigger_than_max",
|
||||||
|
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||||
|
"details": f"{x}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_value": val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
if ret != True:
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||||
return (False, "{}, {}".format(class_type, ret))
|
for i, r in enumerate(ret):
|
||||||
|
if r is not True:
|
||||||
|
details = f"{x}"
|
||||||
|
if r is not False:
|
||||||
|
details += f" - {str(r)}"
|
||||||
|
|
||||||
|
error = {
|
||||||
|
"type": "custom_validation_failed",
|
||||||
|
"message": "Custom validation failed for node",
|
||||||
|
"details": details,
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": info,
|
||||||
|
"received_value": val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
input_config = info
|
||||||
return (True, "")
|
list_info = ""
|
||||||
|
|
||||||
|
# Don't send back gigantic lists like if they're lots of
|
||||||
|
# scanned model filepaths
|
||||||
|
if len(type_input) > 20:
|
||||||
|
list_info = f"(list of length {len(type_input)})"
|
||||||
|
input_config = None
|
||||||
|
else:
|
||||||
|
list_info = str(type_input)
|
||||||
|
|
||||||
|
error = {
|
||||||
|
"type": "value_not_in_list",
|
||||||
|
"message": "Value not in list",
|
||||||
|
"details": f"{x}: '{val}' not in {list_info}",
|
||||||
|
"extra_info": {
|
||||||
|
"input_name": x,
|
||||||
|
"input_config": input_config,
|
||||||
|
"received_value": val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errors.append(error)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(errors) > 0 or valid is not True:
|
||||||
|
ret = (False, errors, unique_id)
|
||||||
|
else:
|
||||||
|
ret = (True, [], unique_id)
|
||||||
|
|
||||||
|
validated[unique_id] = ret
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def full_type_name(klass):
|
||||||
|
module = klass.__module__
|
||||||
|
if module == 'builtins':
|
||||||
|
return klass.__qualname__
|
||||||
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
def validate_prompt(prompt):
|
def validate_prompt(prompt):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
@ -277,34 +581,86 @@ def validate_prompt(prompt):
|
|||||||
outputs.add(x)
|
outputs.add(x)
|
||||||
|
|
||||||
if len(outputs) == 0:
|
if len(outputs) == 0:
|
||||||
return (False, "Prompt has no outputs")
|
error = {
|
||||||
|
"type": "prompt_no_outputs",
|
||||||
|
"message": "Prompt has no outputs",
|
||||||
|
"details": "",
|
||||||
|
"extra_info": {}
|
||||||
|
}
|
||||||
|
return (False, error, [], [])
|
||||||
|
|
||||||
good_outputs = set()
|
good_outputs = set()
|
||||||
errors = []
|
errors = []
|
||||||
|
node_errors = {}
|
||||||
|
validated = {}
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
valid = False
|
valid = False
|
||||||
reason = ""
|
reasons = []
|
||||||
try:
|
try:
|
||||||
m = validate_inputs(prompt, o)
|
m = validate_inputs(prompt, o, validated)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
reason = m[1]
|
reasons = m[1]
|
||||||
except Exception as e:
|
except Exception as ex:
|
||||||
print(traceback.format_exc())
|
typ, _, tb = sys.exc_info()
|
||||||
valid = False
|
valid = False
|
||||||
reason = "Parsing error"
|
exception_type = full_type_name(typ)
|
||||||
|
reasons = [{
|
||||||
|
"type": "exception_during_validation",
|
||||||
|
"message": "Exception when validating node",
|
||||||
|
"details": str(ex),
|
||||||
|
"extra_info": {
|
||||||
|
"exception_type": exception_type,
|
||||||
|
"traceback": traceback.format_tb(tb)
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
validated[o] = (False, reasons, o)
|
||||||
|
|
||||||
if valid == True:
|
if valid is True:
|
||||||
good_outputs.add(x)
|
good_outputs.add(o)
|
||||||
else:
|
else:
|
||||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
print(f"Failed to validate prompt for output {o}:")
|
||||||
print("output will be ignored")
|
if len(reasons) > 0:
|
||||||
errors += [(o, reason)]
|
print("* (prompt):")
|
||||||
|
for reason in reasons:
|
||||||
|
print(f" - {reason['message']}: {reason['details']}")
|
||||||
|
errors += [(o, reasons)]
|
||||||
|
for node_id, result in validated.items():
|
||||||
|
valid = result[0]
|
||||||
|
reasons = result[1]
|
||||||
|
# If a node upstream has errors, the nodes downstream will also
|
||||||
|
# be reported as invalid, but there will be no errors attached.
|
||||||
|
# So don't return those nodes as having errors in the response.
|
||||||
|
if valid is not True and len(reasons) > 0:
|
||||||
|
if node_id not in node_errors:
|
||||||
|
class_type = prompt[node_id]['class_type']
|
||||||
|
node_errors[node_id] = {
|
||||||
|
"errors": reasons,
|
||||||
|
"dependent_outputs": [],
|
||||||
|
"class_type": class_type
|
||||||
|
}
|
||||||
|
print(f"* {class_type} {node_id}:")
|
||||||
|
for reason in reasons:
|
||||||
|
print(f" - {reason['message']}: {reason['details']}")
|
||||||
|
node_errors[node_id]["dependent_outputs"].append(o)
|
||||||
|
print("Output will be ignored")
|
||||||
|
|
||||||
if len(good_outputs) == 0:
|
if len(good_outputs) == 0:
|
||||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
errors_list = []
|
||||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
for o, errors in errors:
|
||||||
|
for error in errors:
|
||||||
|
errors_list.append(f"{error['message']}: {error['details']}")
|
||||||
|
errors_list = "\n".join(errors_list)
|
||||||
|
|
||||||
return (True, "")
|
error = {
|
||||||
|
"type": "prompt_outputs_failed_validation",
|
||||||
|
"message": "Prompt outputs failed validation",
|
||||||
|
"details": errors_list,
|
||||||
|
"extra_info": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (False, error, list(good_outputs), node_errors)
|
||||||
|
|
||||||
|
return (True, None, list(good_outputs), node_errors)
|
||||||
|
|
||||||
|
|
||||||
class PromptQueue:
|
class PromptQueue:
|
||||||
@ -340,8 +696,7 @@ class PromptQueue:
|
|||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
if "ui" in outputs[o]:
|
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||||
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
|
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
|
|||||||
@ -1,14 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
supported_ckpt_extensions = set(['.ckpt', '.pth'])
|
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
|
||||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
|
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
|
||||||
try:
|
|
||||||
import safetensors.torch
|
|
||||||
supported_ckpt_extensions.add('.safetensors')
|
|
||||||
supported_pt_extensions.add('.safetensors')
|
|
||||||
except:
|
|
||||||
print("Could not import safetensors, safetensors support disabled.")
|
|
||||||
|
|
||||||
|
|
||||||
folder_names_and_paths = {}
|
folder_names_and_paths = {}
|
||||||
|
|
||||||
@ -38,6 +32,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
|
|||||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||||
|
|
||||||
|
filename_list_cache = {}
|
||||||
|
|
||||||
if not os.path.exists(input_directory):
|
if not os.path.exists(input_directory):
|
||||||
os.makedirs(input_directory)
|
os.makedirs(input_directory)
|
||||||
|
|
||||||
@ -118,12 +114,18 @@ def get_folder_paths(folder_name):
|
|||||||
return folder_names_and_paths[folder_name][0][:]
|
return folder_names_and_paths[folder_name][0][:]
|
||||||
|
|
||||||
def recursive_search(directory):
|
def recursive_search(directory):
|
||||||
|
if not os.path.isdir(directory):
|
||||||
|
return [], {}
|
||||||
result = []
|
result = []
|
||||||
|
dirs = {directory: os.path.getmtime(directory)}
|
||||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||||
for filepath in file:
|
for filepath in file:
|
||||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
#we os.path,join directory with a blank string to generate a path separator at the end.
|
||||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||||
return result
|
for d in subdir:
|
||||||
|
path = os.path.join(root, d)
|
||||||
|
dirs[path] = os.path.getmtime(path)
|
||||||
|
return result, dirs
|
||||||
|
|
||||||
def filter_files_extensions(files, extensions):
|
def filter_files_extensions(files, extensions):
|
||||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||||
@ -132,19 +134,90 @@ def filter_files_extensions(files, extensions):
|
|||||||
|
|
||||||
def get_full_path(folder_name, filename):
|
def get_full_path(folder_name, filename):
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
|
if folder_name not in folder_names_and_paths:
|
||||||
|
return None
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
|
filename = os.path.relpath(os.path.join("/", filename), "/")
|
||||||
for x in folders[0]:
|
for x in folders[0]:
|
||||||
full_path = os.path.join(x, filename)
|
full_path = os.path.join(x, filename)
|
||||||
if os.path.isfile(full_path):
|
if os.path.isfile(full_path):
|
||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_filename_list(folder_name):
|
def get_filename_list_(folder_name):
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
output_list = set()
|
output_list = set()
|
||||||
folders = folder_names_and_paths[folder_name]
|
folders = folder_names_and_paths[folder_name]
|
||||||
|
output_folders = {}
|
||||||
for x in folders[0]:
|
for x in folders[0]:
|
||||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
files, folders_all = recursive_search(x)
|
||||||
return sorted(list(output_list))
|
output_list.update(filter_files_extensions(files, folders[1]))
|
||||||
|
output_folders = {**output_folders, **folders_all}
|
||||||
|
|
||||||
|
return (sorted(list(output_list)), output_folders, time.perf_counter())
|
||||||
|
|
||||||
|
def cached_filename_list_(folder_name):
|
||||||
|
global filename_list_cache
|
||||||
|
global folder_names_and_paths
|
||||||
|
if folder_name not in filename_list_cache:
|
||||||
|
return None
|
||||||
|
out = filename_list_cache[folder_name]
|
||||||
|
if time.perf_counter() < (out[2] + 0.5):
|
||||||
|
return out
|
||||||
|
for x in out[1]:
|
||||||
|
time_modified = out[1][x]
|
||||||
|
folder = x
|
||||||
|
if os.path.getmtime(folder) != time_modified:
|
||||||
|
return None
|
||||||
|
|
||||||
|
folders = folder_names_and_paths[folder_name]
|
||||||
|
for x in folders[0]:
|
||||||
|
if os.path.isdir(x):
|
||||||
|
if x not in out[1]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def get_filename_list(folder_name):
|
||||||
|
out = cached_filename_list_(folder_name)
|
||||||
|
if out is None:
|
||||||
|
out = get_filename_list_(folder_name)
|
||||||
|
global filename_list_cache
|
||||||
|
filename_list_cache[folder_name] = out
|
||||||
|
return list(out[0])
|
||||||
|
|
||||||
|
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||||
|
def map_filename(filename):
|
||||||
|
prefix_len = len(os.path.basename(filename_prefix))
|
||||||
|
prefix = filename[:prefix_len + 1]
|
||||||
|
try:
|
||||||
|
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||||
|
except:
|
||||||
|
digits = 0
|
||||||
|
return (digits, prefix)
|
||||||
|
|
||||||
|
def compute_vars(input, image_width, image_height):
|
||||||
|
input = input.replace("%width%", str(image_width))
|
||||||
|
input = input.replace("%height%", str(image_height))
|
||||||
|
return input
|
||||||
|
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|
||||||
|
full_output_folder = os.path.join(output_dir, subfolder)
|
||||||
|
|
||||||
|
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||||
|
print("Saving image outside the output folder is not allowed.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||||
|
except ValueError:
|
||||||
|
counter = 1
|
||||||
|
except FileNotFoundError:
|
||||||
|
os.makedirs(full_output_folder, exist_ok=True)
|
||||||
|
counter = 1
|
||||||
|
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||||
|
|||||||
4
main.py
4
main.py
@ -33,8 +33,8 @@ def prompt_worker(q, server):
|
|||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
item, item_id = q.get()
|
item, item_id = q.get()
|
||||||
e.execute(item[-2], item[-1])
|
e.execute(item[2], item[1], item[3], item[4])
|
||||||
q.task_done(item_id, e.outputs)
|
q.task_done(item_id, e.outputs_ui)
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||||
|
|||||||
270
nodes.py
270
nodes.py
@ -6,16 +6,18 @@ import json
|
|||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||||
|
|
||||||
|
|
||||||
import comfy.diffusers_convert
|
import comfy.diffusers_load
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.sample
|
import comfy.sample
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
@ -28,6 +30,7 @@ import importlib
|
|||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
@ -145,9 +148,6 @@ class ConditioningSetMask:
|
|||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||||
@ -160,9 +160,6 @@ class VAEDecode:
|
|||||||
return (vae.decode(samples["samples"]), )
|
return (vae.decode(samples["samples"]), )
|
||||||
|
|
||||||
class VAEDecodeTiled:
|
class VAEDecodeTiled:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||||
@ -175,9 +172,6 @@ class VAEDecodeTiled:
|
|||||||
return (vae.decode_tiled(samples["samples"]), )
|
return (vae.decode_tiled(samples["samples"]), )
|
||||||
|
|
||||||
class VAEEncode:
|
class VAEEncode:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||||
@ -202,9 +196,6 @@ class VAEEncode:
|
|||||||
return ({"samples":t}, )
|
return ({"samples":t}, )
|
||||||
|
|
||||||
class VAEEncodeTiled:
|
class VAEEncodeTiled:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||||
@ -219,9 +210,6 @@ class VAEEncodeTiled:
|
|||||||
return ({"samples":t}, )
|
return ({"samples":t}, )
|
||||||
|
|
||||||
class VAEEncodeForInpaint:
|
class VAEEncodeForInpaint:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
||||||
@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
|
|||||||
|
|
||||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLatent:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT", ),
|
||||||
|
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
# support save metadata for latent sharing
|
||||||
|
prompt_info = ""
|
||||||
|
if prompt is not None:
|
||||||
|
prompt_info = json.dumps(prompt)
|
||||||
|
|
||||||
|
metadata = {"prompt": prompt_info}
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
|
file = f"{filename}_{counter:05}_.latent"
|
||||||
|
file = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
|
output = {}
|
||||||
|
output["latent_tensor"] = samples["samples"]
|
||||||
|
|
||||||
|
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||||
|
return {"required": {"latent": [sorted(files), ]}, }
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, latent):
|
||||||
|
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||||
|
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||||
|
samples = {"samples": latent["latent_tensor"].float()}
|
||||||
|
return (samples, )
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, latent):
|
||||||
|
image_path = folder_paths.get_annotated_filepath(latent)
|
||||||
|
m = hashlib.sha256()
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
m.update(f.read())
|
||||||
|
return m.digest().hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, latent):
|
||||||
|
if not folder_paths.exists_annotated_filepath(latent):
|
||||||
|
return "Invalid latent file: {}".format(latent)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class CheckpointLoader:
|
class CheckpointLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -296,7 +359,10 @@ class DiffusersLoader:
|
|||||||
paths = []
|
paths = []
|
||||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||||
if os.path.exists(search_path):
|
if os.path.exists(search_path):
|
||||||
paths += next(os.walk(search_path))[1]
|
for root, subdir, files in os.walk(search_path, followlinks=True):
|
||||||
|
if "model_index.json" in files:
|
||||||
|
paths.append(os.path.relpath(root, start=search_path))
|
||||||
|
|
||||||
return {"required": {"model_path": (paths,), }}
|
return {"required": {"model_path": (paths,), }}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
@ -306,12 +372,12 @@ class DiffusersLoader:
|
|||||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||||
if os.path.exists(search_path):
|
if os.path.exists(search_path):
|
||||||
paths = next(os.walk(search_path))[1]
|
path = os.path.join(search_path, model_path)
|
||||||
if model_path in paths:
|
if os.path.exists(path):
|
||||||
model_path = os.path.join(search_path, model_path)
|
model_path = path
|
||||||
break
|
break
|
||||||
|
|
||||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
|
|
||||||
class unCLIPCheckpointLoader:
|
class unCLIPCheckpointLoader:
|
||||||
@ -360,6 +426,9 @@ class LoraLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||||
|
if strength_model == 0 and strength_clip == 0:
|
||||||
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
|
||||||
return (model_lora, clip_lora)
|
return (model_lora, clip_lora)
|
||||||
@ -517,9 +586,11 @@ class ControlNetApply:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||||
|
if strength == 0:
|
||||||
|
return (conditioning, )
|
||||||
|
|
||||||
c = []
|
c = []
|
||||||
control_hint = image.movedim(-1,1)
|
control_hint = image.movedim(-1,1)
|
||||||
print(control_hint.shape)
|
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
||||||
@ -624,6 +695,9 @@ class unCLIPConditioning:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
|
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
|
||||||
|
if strength == 0:
|
||||||
|
return (conditioning, )
|
||||||
|
|
||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
o = t[1].copy()
|
o = t[1].copy()
|
||||||
@ -706,22 +780,61 @@ class LatentFromBatch:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return {"required": { "samples": ("LATENT",),
|
||||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||||
|
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "rotate"
|
FUNCTION = "frombatch"
|
||||||
|
|
||||||
CATEGORY = "latent"
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
def rotate(self, samples, batch_index):
|
def frombatch(self, samples, batch_index, length):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
s_in = samples["samples"]
|
s_in = samples["samples"]
|
||||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||||
s["samples"] = s_in[batch_index:batch_index + 1].clone()
|
length = min(s_in.shape[0] - batch_index, length)
|
||||||
s["batch_index"] = batch_index
|
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||||
|
if "noise_mask" in samples:
|
||||||
|
masks = samples["noise_mask"]
|
||||||
|
if masks.shape[0] == 1:
|
||||||
|
s["noise_mask"] = masks.clone()
|
||||||
|
else:
|
||||||
|
if masks.shape[0] < s_in.shape[0]:
|
||||||
|
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||||
|
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||||
|
if "batch_index" not in s:
|
||||||
|
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||||
|
else:
|
||||||
|
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
class RepeatLatentBatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "repeat"
|
||||||
|
|
||||||
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
|
def repeat(self, samples, amount):
|
||||||
|
s = samples.copy()
|
||||||
|
s_in = samples["samples"]
|
||||||
|
|
||||||
|
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||||
|
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||||
|
masks = samples["noise_mask"]
|
||||||
|
if masks.shape[0] < s_in.shape[0]:
|
||||||
|
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||||
|
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||||
|
if "batch_index" in s:
|
||||||
|
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||||
|
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class LatentUpscale:
|
class LatentUpscale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -740,6 +853,25 @@ class LatentUpscale:
|
|||||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
class LatentUpscaleBy:
|
||||||
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
|
||||||
|
"scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "upscale"
|
||||||
|
|
||||||
|
CATEGORY = "latent"
|
||||||
|
|
||||||
|
def upscale(self, samples, upscale_method, scale_by):
|
||||||
|
s = samples.copy()
|
||||||
|
width = round(samples["samples"].shape[3] * scale_by)
|
||||||
|
height = round(samples["samples"].shape[2] * scale_by)
|
||||||
|
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
||||||
|
return (s,)
|
||||||
|
|
||||||
class LatentRotate:
|
class LatentRotate:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -872,7 +1004,7 @@ class SetLatentNoiseMask:
|
|||||||
|
|
||||||
def set_mask(self, samples, mask):
|
def set_mask(self, samples, mask):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
s["noise_mask"] = mask
|
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||||
@ -882,8 +1014,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
if disable_noise:
|
if disable_noise:
|
||||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||||
else:
|
else:
|
||||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||||
|
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
@ -978,39 +1110,7 @@ class SaveImage:
|
|||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
def map_filename(filename):
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||||
prefix_len = len(os.path.basename(filename_prefix))
|
|
||||||
prefix = filename[:prefix_len + 1]
|
|
||||||
try:
|
|
||||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
|
||||||
except:
|
|
||||||
digits = 0
|
|
||||||
return (digits, prefix)
|
|
||||||
|
|
||||||
def compute_vars(input):
|
|
||||||
input = input.replace("%width%", str(images[0].shape[1]))
|
|
||||||
input = input.replace("%height%", str(images[0].shape[0]))
|
|
||||||
return input
|
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix)
|
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
|
||||||
|
|
||||||
full_output_folder = os.path.join(self.output_dir, subfolder)
|
|
||||||
|
|
||||||
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
|
||||||
print("Saving image outside the output folder is not allowed.")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
|
||||||
except ValueError:
|
|
||||||
counter = 1
|
|
||||||
except FileNotFoundError:
|
|
||||||
os.makedirs(full_output_folder, exist_ok=True)
|
|
||||||
counter = 1
|
|
||||||
|
|
||||||
results = list()
|
results = list()
|
||||||
for image in images:
|
for image in images:
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.cpu().numpy()
|
||||||
@ -1049,8 +1149,9 @@ class LoadImage:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(os.listdir(input_dir)), )},
|
{"image": (sorted(files), )},
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
@ -1060,6 +1161,7 @@ class LoadImage:
|
|||||||
def load_image(self, image):
|
def load_image(self, image):
|
||||||
image_path = folder_paths.get_annotated_filepath(image)
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
|
i = ImageOps.exif_transpose(i)
|
||||||
image = i.convert("RGB")
|
image = i.convert("RGB")
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = torch.from_numpy(image)[None,]
|
image = torch.from_numpy(image)[None,]
|
||||||
@ -1090,9 +1192,10 @@ class LoadImageMask:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(os.listdir(input_dir)), ),
|
{"image": (sorted(files), ),
|
||||||
"channel": (s._color_channels, ),}
|
"channel": (s._color_channels, ), }
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "mask"
|
CATEGORY = "mask"
|
||||||
@ -1102,6 +1205,7 @@ class LoadImageMask:
|
|||||||
def load_image(self, image, channel):
|
def load_image(self, image, channel):
|
||||||
image_path = folder_paths.get_annotated_filepath(image)
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
|
i = ImageOps.exif_transpose(i)
|
||||||
if i.getbands() != ("R", "G", "B", "A"):
|
if i.getbands() != ("R", "G", "B", "A"):
|
||||||
i = i.convert("RGBA")
|
i = i.convert("RGBA")
|
||||||
mask = None
|
mask = None
|
||||||
@ -1244,7 +1348,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"VAELoader": VAELoader,
|
"VAELoader": VAELoader,
|
||||||
"EmptyLatentImage": EmptyLatentImage,
|
"EmptyLatentImage": EmptyLatentImage,
|
||||||
"LatentUpscale": LatentUpscale,
|
"LatentUpscale": LatentUpscale,
|
||||||
|
"LatentUpscaleBy": LatentUpscaleBy,
|
||||||
"LatentFromBatch": LatentFromBatch,
|
"LatentFromBatch": LatentFromBatch,
|
||||||
|
"RepeatLatentBatch": RepeatLatentBatch,
|
||||||
"SaveImage": SaveImage,
|
"SaveImage": SaveImage,
|
||||||
"PreviewImage": PreviewImage,
|
"PreviewImage": PreviewImage,
|
||||||
"LoadImage": LoadImage,
|
"LoadImage": LoadImage,
|
||||||
@ -1282,6 +1388,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
"DiffusersLoader": DiffusersLoader,
|
"DiffusersLoader": DiffusersLoader,
|
||||||
|
|
||||||
|
"LoadLatent": LoadLatent,
|
||||||
|
"SaveLatent": SaveLatent
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -1319,7 +1428,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LatentCrop": "Crop Latent",
|
"LatentCrop": "Crop Latent",
|
||||||
"EmptyLatentImage": "Empty Latent Image",
|
"EmptyLatentImage": "Empty Latent Image",
|
||||||
"LatentUpscale": "Upscale Latent",
|
"LatentUpscale": "Upscale Latent",
|
||||||
|
"LatentUpscaleBy": "Upscale Latent By",
|
||||||
"LatentComposite": "Latent Composite",
|
"LatentComposite": "Latent Composite",
|
||||||
|
"LatentFromBatch" : "Latent From Batch",
|
||||||
|
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||||
# Image
|
# Image
|
||||||
"SaveImage": "Save Image",
|
"SaveImage": "Save Image",
|
||||||
"PreviewImage": "Preview Image",
|
"PreviewImage": "Preview Image",
|
||||||
@ -1351,14 +1463,18 @@ def load_custom_node(module_path):
|
|||||||
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||||
|
return False
|
||||||
|
|
||||||
def load_custom_nodes():
|
def load_custom_nodes():
|
||||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||||
|
node_import_times = []
|
||||||
for custom_node_path in node_paths:
|
for custom_node_path in node_paths:
|
||||||
possible_modules = os.listdir(custom_node_path)
|
possible_modules = os.listdir(custom_node_path)
|
||||||
if "__pycache__" in possible_modules:
|
if "__pycache__" in possible_modules:
|
||||||
@ -1367,11 +1483,25 @@ def load_custom_nodes():
|
|||||||
for possible_module in possible_modules:
|
for possible_module in possible_modules:
|
||||||
module_path = os.path.join(custom_node_path, possible_module)
|
module_path = os.path.join(custom_node_path, possible_module)
|
||||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||||
load_custom_node(module_path)
|
if module_path.endswith(".disabled"): continue
|
||||||
|
time_before = time.perf_counter()
|
||||||
|
success = load_custom_node(module_path)
|
||||||
|
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||||
|
|
||||||
|
if len(node_import_times) > 0:
|
||||||
|
print("\nImport times for custom nodes:")
|
||||||
|
for n in sorted(node_import_times):
|
||||||
|
if n[2]:
|
||||||
|
import_message = ""
|
||||||
|
else:
|
||||||
|
import_message = " (IMPORT FAILED)"
|
||||||
|
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||||
|
print()
|
||||||
|
|
||||||
def init_custom_nodes():
|
def init_custom_nodes():
|
||||||
load_custom_nodes()
|
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
|
load_custom_nodes()
|
||||||
|
|||||||
@ -175,6 +175,8 @@
|
|||||||
"import threading\n",
|
"import threading\n",
|
||||||
"import time\n",
|
"import time\n",
|
||||||
"import socket\n",
|
"import socket\n",
|
||||||
|
"import urllib.request\n",
|
||||||
|
"\n",
|
||||||
"def iframe_thread(port):\n",
|
"def iframe_thread(port):\n",
|
||||||
" while True:\n",
|
" while True:\n",
|
||||||
" time.sleep(0.5)\n",
|
" time.sleep(0.5)\n",
|
||||||
@ -183,7 +185,9 @@
|
|||||||
" if result == 0:\n",
|
" if result == 0:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
" sock.close()\n",
|
" sock.close()\n",
|
||||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
|
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||||
" for line in p.stdout:\n",
|
" for line in p.stdout:\n",
|
||||||
" print(line.decode(), end='')\n",
|
" print(line.decode(), end='')\n",
|
||||||
|
|||||||
250
server.py
250
server.py
@ -7,6 +7,9 @@ import execution
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@ -19,7 +22,8 @@ except ImportError:
|
|||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
@ -78,7 +82,7 @@ class PromptServer():
|
|||||||
# Reusing existing session, remove old
|
# Reusing existing session, remove old
|
||||||
self.sockets.pop(sid, None)
|
self.sockets.pop(sid, None)
|
||||||
else:
|
else:
|
||||||
sid = uuid.uuid4().hex
|
sid = uuid.uuid4().hex
|
||||||
|
|
||||||
self.sockets[sid] = ws
|
self.sockets[sid] = ws
|
||||||
|
|
||||||
@ -110,49 +114,96 @@ class PromptServer():
|
|||||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
||||||
|
|
||||||
@routes.post("/upload/image")
|
def get_dir_by_type(dir_type):
|
||||||
async def upload_image(request):
|
if dir_type is None:
|
||||||
post = await request.post()
|
dir_type = "input"
|
||||||
|
|
||||||
|
if dir_type == "input":
|
||||||
|
type_dir = folder_paths.get_input_directory()
|
||||||
|
elif dir_type == "temp":
|
||||||
|
type_dir = folder_paths.get_temp_directory()
|
||||||
|
elif dir_type == "output":
|
||||||
|
type_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
return type_dir, dir_type
|
||||||
|
|
||||||
|
def image_upload(post, image_save_function=None):
|
||||||
image = post.get("image")
|
image = post.get("image")
|
||||||
|
overwrite = post.get("overwrite")
|
||||||
|
|
||||||
if post.get("type") is None:
|
image_upload_type = post.get("type")
|
||||||
upload_dir = folder_paths.get_input_directory()
|
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
|
||||||
elif post.get("type") == "input":
|
|
||||||
upload_dir = folder_paths.get_input_directory()
|
|
||||||
elif post.get("type") == "temp":
|
|
||||||
upload_dir = folder_paths.get_temp_directory()
|
|
||||||
elif post.get("type") == "output":
|
|
||||||
upload_dir = folder_paths.get_output_directory()
|
|
||||||
|
|
||||||
if not os.path.exists(upload_dir):
|
|
||||||
os.makedirs(upload_dir)
|
|
||||||
|
|
||||||
if image and image.file:
|
if image and image.file:
|
||||||
filename = image.filename
|
filename = image.filename
|
||||||
if not filename:
|
if not filename:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
subfolder = post.get("subfolder", "")
|
||||||
|
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
|
||||||
|
|
||||||
|
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
|
||||||
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
if not os.path.exists(full_output_folder):
|
||||||
|
os.makedirs(full_output_folder)
|
||||||
|
|
||||||
split = os.path.splitext(filename)
|
split = os.path.splitext(filename)
|
||||||
i = 1
|
filepath = os.path.join(full_output_folder, filename)
|
||||||
while os.path.exists(os.path.join(upload_dir, filename)):
|
|
||||||
filename = f"{split[0]} ({i}){split[1]}"
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
filepath = os.path.join(upload_dir, filename)
|
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
i = 1
|
||||||
|
while os.path.exists(filepath):
|
||||||
|
filename = f"{split[0]} ({i}){split[1]}"
|
||||||
|
filepath = os.path.join(full_output_folder, filename)
|
||||||
|
i += 1
|
||||||
|
|
||||||
with open(filepath, "wb") as f:
|
if image_save_function is not None:
|
||||||
f.write(image.file.read())
|
image_save_function(image, post, filepath)
|
||||||
|
else:
|
||||||
return web.json_response({"name" : filename})
|
with open(filepath, "wb") as f:
|
||||||
|
f.write(image.file.read())
|
||||||
|
|
||||||
|
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||||
else:
|
else:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
@routes.post("/upload/image")
|
||||||
|
async def upload_image(request):
|
||||||
|
post = await request.post()
|
||||||
|
return image_upload(post)
|
||||||
|
|
||||||
|
@routes.post("/upload/mask")
|
||||||
|
async def upload_mask(request):
|
||||||
|
post = await request.post()
|
||||||
|
|
||||||
|
def image_save_function(image, post, filepath):
|
||||||
|
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
|
||||||
|
mask_pil = Image.open(image.file).convert('RGBA')
|
||||||
|
|
||||||
|
# alpha copy
|
||||||
|
new_alpha = mask_pil.getchannel('A')
|
||||||
|
original_pil.putalpha(new_alpha)
|
||||||
|
original_pil.save(filepath, compress_level=4)
|
||||||
|
|
||||||
|
return image_upload(post, image_save_function)
|
||||||
|
|
||||||
@routes.get("/view")
|
@routes.get("/view")
|
||||||
async def view_image(request):
|
async def view_image(request):
|
||||||
if "filename" in request.rel_url.query:
|
if "filename" in request.rel_url.query:
|
||||||
type = request.rel_url.query.get("type", "output")
|
filename = request.rel_url.query["filename"]
|
||||||
output_dir = folder_paths.get_directory_by_type(type)
|
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||||
|
|
||||||
|
# validation for security: prevent accessing arbitrary path
|
||||||
|
if filename[0] == '/' or '..' in filename:
|
||||||
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
type = request.rel_url.query.get("type", "output")
|
||||||
|
output_dir = folder_paths.get_directory_by_type(type)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
@ -162,35 +213,132 @@ class PromptServer():
|
|||||||
return web.Response(status=403)
|
return web.Response(status=403)
|
||||||
output_dir = full_output_dir
|
output_dir = full_output_dir
|
||||||
|
|
||||||
filename = request.rel_url.query["filename"]
|
|
||||||
filename = os.path.basename(filename)
|
filename = os.path.basename(filename)
|
||||||
file = os.path.join(output_dir, filename)
|
file = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
if 'channel' not in request.rel_url.query:
|
||||||
|
channel = 'rgba'
|
||||||
|
else:
|
||||||
|
channel = request.rel_url.query["channel"]
|
||||||
|
|
||||||
|
if channel == 'rgb':
|
||||||
|
with Image.open(file) as img:
|
||||||
|
if img.mode == "RGBA":
|
||||||
|
r, g, b, a = img.split()
|
||||||
|
new_img = Image.merge('RGB', (r, g, b))
|
||||||
|
else:
|
||||||
|
new_img = img.convert("RGB")
|
||||||
|
|
||||||
|
buffer = BytesIO()
|
||||||
|
new_img.save(buffer, format='PNG')
|
||||||
|
buffer.seek(0)
|
||||||
|
|
||||||
|
return web.Response(body=buffer.read(), content_type='image/png',
|
||||||
|
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||||
|
|
||||||
|
elif channel == 'a':
|
||||||
|
with Image.open(file) as img:
|
||||||
|
if img.mode == "RGBA":
|
||||||
|
_, _, _, a = img.split()
|
||||||
|
else:
|
||||||
|
a = Image.new('L', img.size, 255)
|
||||||
|
|
||||||
|
# alpha img
|
||||||
|
alpha_img = Image.new('RGBA', img.size)
|
||||||
|
alpha_img.putalpha(a)
|
||||||
|
alpha_buffer = BytesIO()
|
||||||
|
alpha_img.save(alpha_buffer, format='PNG')
|
||||||
|
alpha_buffer.seek(0)
|
||||||
|
|
||||||
|
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||||
|
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||||
|
else:
|
||||||
|
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||||
|
|
||||||
return web.Response(status=404)
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
@routes.get("/view_metadata/{folder_name}")
|
||||||
|
async def view_metadata(request):
|
||||||
|
folder_name = request.match_info.get("folder_name", None)
|
||||||
|
if folder_name is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
if not "filename" in request.rel_url.query:
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
filename = request.rel_url.query["filename"]
|
||||||
|
if not filename.endswith(".safetensors"):
|
||||||
|
return web.Response(status=404)
|
||||||
|
|
||||||
|
safetensors_path = folder_paths.get_full_path(folder_name, filename)
|
||||||
|
if safetensors_path is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
|
||||||
|
if out is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
dt = json.loads(out)
|
||||||
|
if not "__metadata__" in dt:
|
||||||
|
return web.Response(status=404)
|
||||||
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
|
@routes.get("/system_stats")
|
||||||
|
async def get_queue(request):
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||||
|
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||||
|
system_stats = {
|
||||||
|
"devices": [
|
||||||
|
{
|
||||||
|
"name": device_name,
|
||||||
|
"type": device.type,
|
||||||
|
"index": device.index,
|
||||||
|
"vram_total": vram_total,
|
||||||
|
"vram_free": vram_free,
|
||||||
|
"torch_vram_total": torch_vram_total,
|
||||||
|
"torch_vram_free": torch_vram_free,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return web.json_response(system_stats)
|
||||||
|
|
||||||
@routes.get("/prompt")
|
@routes.get("/prompt")
|
||||||
async def get_prompt(request):
|
async def get_prompt(request):
|
||||||
return web.json_response(self.get_queue_info())
|
return web.json_response(self.get_queue_info())
|
||||||
|
|
||||||
|
def node_info(node_class):
|
||||||
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
info = {}
|
||||||
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||||
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
|
info['name'] = node_class
|
||||||
|
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||||
|
info['description'] = ''
|
||||||
|
info['category'] = 'sd'
|
||||||
|
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||||
|
info['output_node'] = True
|
||||||
|
else:
|
||||||
|
info['output_node'] = False
|
||||||
|
|
||||||
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
|
info['category'] = obj_class.CATEGORY
|
||||||
|
return info
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
out[x] = node_info(x)
|
||||||
info = {}
|
return web.json_response(out)
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
@routes.get("/object_info/{node_class}")
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
async def get_object_info_node(request):
|
||||||
info['name'] = x
|
node_class = request.match_info.get("node_class", None)
|
||||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
out = {}
|
||||||
info['description'] = ''
|
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||||
info['category'] = 'sd'
|
out[node_class] = node_info(node_class)
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
|
||||||
info['category'] = obj_class.CATEGORY
|
|
||||||
out[x] = info
|
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/history")
|
@routes.get("/history")
|
||||||
@ -232,14 +380,16 @@ class PromptServer():
|
|||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
prompt_id = str(uuid.uuid4())
|
||||||
|
outputs_to_execute = valid[2]
|
||||||
|
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||||
|
return web.json_response({"prompt_id": prompt_id, "number": number})
|
||||||
else:
|
else:
|
||||||
resp_code = 400
|
|
||||||
out_string = valid[1]
|
|
||||||
print("invalid prompt:", valid[1])
|
print("invalid prompt:", valid[1])
|
||||||
|
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||||
|
|
||||||
return web.Response(body=out_string, status=resp_code)
|
|
||||||
|
|
||||||
@routes.post("/queue")
|
@routes.post("/queue")
|
||||||
async def post_queue(request):
|
async def post_queue(request):
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
@ -249,9 +399,9 @@ class PromptServer():
|
|||||||
if "delete" in json_data:
|
if "delete" in json_data:
|
||||||
to_delete = json_data['delete']
|
to_delete = json_data['delete']
|
||||||
for id_to_delete in to_delete:
|
for id_to_delete in to_delete:
|
||||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
delete_func = lambda a: a[1] == id_to_delete
|
||||||
self.prompt_queue.delete_queue_item(delete_func)
|
self.prompt_queue.delete_queue_item(delete_func)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
@routes.post("/interrupt")
|
@routes.post("/interrupt")
|
||||||
@ -275,7 +425,7 @@ class PromptServer():
|
|||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.app.add_routes(self.routes)
|
self.app.add_routes(self.routes)
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root),
|
web.static('/', self.web_root, follow_symlinks=True),
|
||||||
])
|
])
|
||||||
|
|
||||||
def get_queue_info(self):
|
def get_queue_info(self):
|
||||||
|
|||||||
166
web/extensions/core/clipspace.js
Normal file
166
web/extensions/core/clipspace.js
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import { app } from "/scripts/app.js";
|
||||||
|
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||||
|
import { ComfyApp } from "/scripts/app.js";
|
||||||
|
|
||||||
|
export class ClipspaceDialog extends ComfyDialog {
|
||||||
|
static items = [];
|
||||||
|
static instance = null;
|
||||||
|
|
||||||
|
static registerButton(name, contextPredicate, callback) {
|
||||||
|
const item =
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: name,
|
||||||
|
contextPredicate: contextPredicate,
|
||||||
|
onclick: callback
|
||||||
|
})
|
||||||
|
|
||||||
|
ClipspaceDialog.items.push(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
static invalidatePreview() {
|
||||||
|
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
|
||||||
|
const img_preview = document.getElementById("clipspace_preview");
|
||||||
|
if(img_preview) {
|
||||||
|
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||||
|
img_preview.style.maxHeight = "100%";
|
||||||
|
img_preview.style.maxWidth = "100%";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static invalidate() {
|
||||||
|
if(ClipspaceDialog.instance) {
|
||||||
|
const self = ClipspaceDialog.instance;
|
||||||
|
// allow reconstruct controls when copying from non-image to image content.
|
||||||
|
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
|
||||||
|
|
||||||
|
if(self.element) {
|
||||||
|
// update
|
||||||
|
self.element.removeChild(self.element.firstChild);
|
||||||
|
self.element.appendChild(children);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// new
|
||||||
|
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(self.element.children[0].children.length <= 1) {
|
||||||
|
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
|
||||||
|
}
|
||||||
|
|
||||||
|
ClipspaceDialog.invalidatePreview();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
|
createButtons(self) {
|
||||||
|
const buttons = [];
|
||||||
|
|
||||||
|
for(let idx in ClipspaceDialog.items) {
|
||||||
|
const item = ClipspaceDialog.items[idx];
|
||||||
|
if(!item.contextPredicate || item.contextPredicate())
|
||||||
|
buttons.push(ClipspaceDialog.items[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
buttons.push(
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Close",
|
||||||
|
onclick: () => { this.close(); }
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
return buttons;
|
||||||
|
}
|
||||||
|
|
||||||
|
createImgSettings() {
|
||||||
|
if(ComfyApp.clipspace.imgs) {
|
||||||
|
const combo_items = [];
|
||||||
|
const imgs = ComfyApp.clipspace.imgs;
|
||||||
|
|
||||||
|
for(let i=0; i < imgs.length; i++) {
|
||||||
|
combo_items.push($el("option", {value:i}, [`${i}`]));
|
||||||
|
}
|
||||||
|
|
||||||
|
const combo1 = $el("select",
|
||||||
|
{id:"clipspace_img_selector", onchange:(event) => {
|
||||||
|
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
|
||||||
|
ClipspaceDialog.invalidatePreview();
|
||||||
|
} }, combo_items);
|
||||||
|
|
||||||
|
const row1 =
|
||||||
|
$el("tr", {},
|
||||||
|
[
|
||||||
|
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
|
||||||
|
$el("td", {}, [combo1])
|
||||||
|
]);
|
||||||
|
|
||||||
|
|
||||||
|
const combo2 = $el("select",
|
||||||
|
{id:"clipspace_img_paste_mode", onchange:(event) => {
|
||||||
|
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
|
||||||
|
} },
|
||||||
|
[
|
||||||
|
$el("option", {value:'selected'}, 'selected'),
|
||||||
|
$el("option", {value:'all'}, 'all')
|
||||||
|
]);
|
||||||
|
combo2.value = ComfyApp.clipspace['img_paste_mode'];
|
||||||
|
|
||||||
|
const row2 =
|
||||||
|
$el("tr", {},
|
||||||
|
[
|
||||||
|
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
|
||||||
|
$el("td", {}, [combo2])
|
||||||
|
]);
|
||||||
|
|
||||||
|
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
|
||||||
|
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
|
||||||
|
|
||||||
|
const row3 =
|
||||||
|
$el("tr", {}, [td]);
|
||||||
|
|
||||||
|
return $el("table", {}, [row1, row2, row3]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createImgPreview() {
|
||||||
|
if(ComfyApp.clipspace.imgs) {
|
||||||
|
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
show() {
|
||||||
|
const img_preview = document.getElementById("clipspace_preview");
|
||||||
|
ClipspaceDialog.invalidate();
|
||||||
|
|
||||||
|
this.element.style.display = "block";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "Comfy.Clipspace",
|
||||||
|
init(app) {
|
||||||
|
app.openClipspace =
|
||||||
|
function () {
|
||||||
|
if(!ClipspaceDialog.instance) {
|
||||||
|
ClipspaceDialog.instance = new ClipspaceDialog(app);
|
||||||
|
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace) {
|
||||||
|
ClipspaceDialog.instance.show();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
app.ui.dialog.show("Clipspace is Empty!");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
@ -174,7 +174,7 @@ const els = {}
|
|||||||
// const ctxMenu = LiteGraph.ContextMenu;
|
// const ctxMenu = LiteGraph.ContextMenu;
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: id,
|
name: id,
|
||||||
init() {
|
addCustomNodeDefs(node_defs) {
|
||||||
const sortObjectKeys = (unordered) => {
|
const sortObjectKeys = (unordered) => {
|
||||||
return Object.keys(unordered).sort().reduce((obj, key) => {
|
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||||
obj[key] = unordered[key];
|
obj[key] = unordered[key];
|
||||||
@ -182,10 +182,10 @@ app.registerExtension({
|
|||||||
}, {});
|
}, {});
|
||||||
};
|
};
|
||||||
|
|
||||||
const getSlotTypes = async () => {
|
function getSlotTypes() {
|
||||||
var types = [];
|
var types = [];
|
||||||
|
|
||||||
const defs = await api.getNodeDefs();
|
const defs = node_defs;
|
||||||
for (const nodeId in defs) {
|
for (const nodeId in defs) {
|
||||||
const nodeData = defs[nodeId];
|
const nodeData = defs[nodeId];
|
||||||
|
|
||||||
@ -212,8 +212,8 @@ app.registerExtension({
|
|||||||
return types;
|
return types;
|
||||||
};
|
};
|
||||||
|
|
||||||
const completeColorPalette = async (colorPalette) => {
|
function completeColorPalette(colorPalette) {
|
||||||
var types = await getSlotTypes();
|
var types = getSlotTypes();
|
||||||
|
|
||||||
for (const type of types) {
|
for (const type of types) {
|
||||||
if (!colorPalette.colors.node_slot[type]) {
|
if (!colorPalette.colors.node_slot[type]) {
|
||||||
|
|||||||
648
web/extensions/core/maskeditor.js
Normal file
648
web/extensions/core/maskeditor.js
Normal file
@ -0,0 +1,648 @@
|
|||||||
|
import { app } from "/scripts/app.js";
|
||||||
|
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||||
|
import { ComfyApp } from "/scripts/app.js";
|
||||||
|
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
|
||||||
|
|
||||||
|
// Helper function to convert a data URL to a Blob object
|
||||||
|
function dataURLToBlob(dataURL) {
|
||||||
|
const parts = dataURL.split(';base64,');
|
||||||
|
const contentType = parts[0].split(':')[1];
|
||||||
|
const byteString = atob(parts[1]);
|
||||||
|
const arrayBuffer = new ArrayBuffer(byteString.length);
|
||||||
|
const uint8Array = new Uint8Array(arrayBuffer);
|
||||||
|
for (let i = 0; i < byteString.length; i++) {
|
||||||
|
uint8Array[i] = byteString.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return new Blob([arrayBuffer], { type: contentType });
|
||||||
|
}
|
||||||
|
|
||||||
|
function loadedImageToBlob(image) {
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
|
||||||
|
canvas.width = image.width;
|
||||||
|
canvas.height = image.height;
|
||||||
|
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
|
||||||
|
ctx.drawImage(image, 0, 0);
|
||||||
|
|
||||||
|
const dataURL = canvas.toDataURL('image/png', 1);
|
||||||
|
const blob = dataURLToBlob(dataURL);
|
||||||
|
|
||||||
|
return blob;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function uploadMask(filepath, formData) {
|
||||||
|
await fetch('/upload/mask', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
}).then(response => {}).catch(error => {
|
||||||
|
console.error('Error:', error);
|
||||||
|
});
|
||||||
|
|
||||||
|
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
|
||||||
|
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace.images)
|
||||||
|
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
|
||||||
|
|
||||||
|
ClipspaceDialog.invalidatePreview();
|
||||||
|
}
|
||||||
|
|
||||||
|
function prepareRGB(image, backupCanvas, backupCtx) {
|
||||||
|
// paste mask data into alpha channel
|
||||||
|
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||||
|
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||||
|
|
||||||
|
// refine mask image
|
||||||
|
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||||
|
if(backupData.data[i+3] == 255)
|
||||||
|
backupData.data[i+3] = 0;
|
||||||
|
else
|
||||||
|
backupData.data[i+3] = 255;
|
||||||
|
|
||||||
|
backupData.data[i] = 0;
|
||||||
|
backupData.data[i+1] = 0;
|
||||||
|
backupData.data[i+2] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
backupCtx.globalCompositeOperation = 'source-over';
|
||||||
|
backupCtx.putImageData(backupData, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
class MaskEditorDialog extends ComfyDialog {
|
||||||
|
static instance = null;
|
||||||
|
|
||||||
|
static getInstance() {
|
||||||
|
if(!MaskEditorDialog.instance) {
|
||||||
|
MaskEditorDialog.instance = new MaskEditorDialog(app);
|
||||||
|
}
|
||||||
|
|
||||||
|
return MaskEditorDialog.instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
is_layout_created = false;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super();
|
||||||
|
this.element = $el("div.comfy-modal", { parent: document.body },
|
||||||
|
[ $el("div.comfy-modal-content",
|
||||||
|
[...this.createButtons()]),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
createButtons() {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
createButton(name, callback) {
|
||||||
|
var button = document.createElement("button");
|
||||||
|
button.innerText = name;
|
||||||
|
button.addEventListener("click", callback);
|
||||||
|
return button;
|
||||||
|
}
|
||||||
|
|
||||||
|
createLeftButton(name, callback) {
|
||||||
|
var button = this.createButton(name, callback);
|
||||||
|
button.style.cssFloat = "left";
|
||||||
|
button.style.marginRight = "4px";
|
||||||
|
return button;
|
||||||
|
}
|
||||||
|
|
||||||
|
createRightButton(name, callback) {
|
||||||
|
var button = this.createButton(name, callback);
|
||||||
|
button.style.cssFloat = "right";
|
||||||
|
button.style.marginLeft = "4px";
|
||||||
|
return button;
|
||||||
|
}
|
||||||
|
|
||||||
|
createLeftSlider(self, name, callback) {
|
||||||
|
const divElement = document.createElement('div');
|
||||||
|
divElement.id = "maskeditor-slider";
|
||||||
|
divElement.style.cssFloat = "left";
|
||||||
|
divElement.style.fontFamily = "sans-serif";
|
||||||
|
divElement.style.marginRight = "4px";
|
||||||
|
divElement.style.color = "var(--input-text)";
|
||||||
|
divElement.style.backgroundColor = "var(--comfy-input-bg)";
|
||||||
|
divElement.style.borderRadius = "8px";
|
||||||
|
divElement.style.borderColor = "var(--border-color)";
|
||||||
|
divElement.style.borderStyle = "solid";
|
||||||
|
divElement.style.fontSize = "15px";
|
||||||
|
divElement.style.height = "21px";
|
||||||
|
divElement.style.padding = "1px 6px";
|
||||||
|
divElement.style.display = "flex";
|
||||||
|
divElement.style.position = "relative";
|
||||||
|
divElement.style.top = "2px";
|
||||||
|
self.brush_slider_input = document.createElement('input');
|
||||||
|
self.brush_slider_input.setAttribute('type', 'range');
|
||||||
|
self.brush_slider_input.setAttribute('min', '1');
|
||||||
|
self.brush_slider_input.setAttribute('max', '100');
|
||||||
|
self.brush_slider_input.setAttribute('value', '10');
|
||||||
|
const labelElement = document.createElement("label");
|
||||||
|
labelElement.textContent = name;
|
||||||
|
|
||||||
|
divElement.appendChild(labelElement);
|
||||||
|
divElement.appendChild(self.brush_slider_input);
|
||||||
|
|
||||||
|
self.brush_slider_input.addEventListener("change", callback);
|
||||||
|
|
||||||
|
return divElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
setlayout(imgCanvas, maskCanvas) {
|
||||||
|
const self = this;
|
||||||
|
|
||||||
|
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
|
||||||
|
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
|
||||||
|
var placeholder = document.createElement("div");
|
||||||
|
placeholder.style.position = "relative";
|
||||||
|
placeholder.style.height = "50px";
|
||||||
|
|
||||||
|
var bottom_panel = document.createElement("div");
|
||||||
|
bottom_panel.style.position = "absolute";
|
||||||
|
bottom_panel.style.bottom = "0px";
|
||||||
|
bottom_panel.style.left = "20px";
|
||||||
|
bottom_panel.style.right = "20px";
|
||||||
|
bottom_panel.style.height = "50px";
|
||||||
|
|
||||||
|
var brush = document.createElement("div");
|
||||||
|
brush.id = "brush";
|
||||||
|
brush.style.backgroundColor = "transparent";
|
||||||
|
brush.style.outline = "1px dashed black";
|
||||||
|
brush.style.boxShadow = "0 0 0 1px white";
|
||||||
|
brush.style.borderRadius = "50%";
|
||||||
|
brush.style.MozBorderRadius = "50%";
|
||||||
|
brush.style.WebkitBorderRadius = "50%";
|
||||||
|
brush.style.position = "absolute";
|
||||||
|
brush.style.zIndex = 8889;
|
||||||
|
brush.style.pointerEvents = "none";
|
||||||
|
this.brush = brush;
|
||||||
|
this.element.appendChild(imgCanvas);
|
||||||
|
this.element.appendChild(maskCanvas);
|
||||||
|
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||||
|
this.element.appendChild(bottom_panel);
|
||||||
|
document.body.appendChild(brush);
|
||||||
|
|
||||||
|
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||||
|
self.brush_size = event.target.value;
|
||||||
|
self.updateBrushPreview(self, null, null);
|
||||||
|
});
|
||||||
|
var clearButton = this.createLeftButton("Clear",
|
||||||
|
() => {
|
||||||
|
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||||
|
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
|
||||||
|
});
|
||||||
|
var cancelButton = this.createRightButton("Cancel", () => {
|
||||||
|
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||||
|
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||||
|
self.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
this.saveButton = this.createRightButton("Save", () => {
|
||||||
|
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||||
|
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||||
|
self.save();
|
||||||
|
});
|
||||||
|
|
||||||
|
this.element.appendChild(imgCanvas);
|
||||||
|
this.element.appendChild(maskCanvas);
|
||||||
|
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||||
|
this.element.appendChild(bottom_panel);
|
||||||
|
|
||||||
|
bottom_panel.appendChild(clearButton);
|
||||||
|
bottom_panel.appendChild(this.saveButton);
|
||||||
|
bottom_panel.appendChild(cancelButton);
|
||||||
|
bottom_panel.appendChild(brush_size_slider);
|
||||||
|
|
||||||
|
imgCanvas.style.position = "relative";
|
||||||
|
imgCanvas.style.top = "200";
|
||||||
|
imgCanvas.style.left = "0";
|
||||||
|
|
||||||
|
maskCanvas.style.position = "absolute";
|
||||||
|
}
|
||||||
|
|
||||||
|
show() {
|
||||||
|
if(!this.is_layout_created) {
|
||||||
|
// layout
|
||||||
|
const imgCanvas = document.createElement('canvas');
|
||||||
|
const maskCanvas = document.createElement('canvas');
|
||||||
|
const backupCanvas = document.createElement('canvas');
|
||||||
|
|
||||||
|
imgCanvas.id = "imageCanvas";
|
||||||
|
maskCanvas.id = "maskCanvas";
|
||||||
|
backupCanvas.id = "backupCanvas";
|
||||||
|
|
||||||
|
this.setlayout(imgCanvas, maskCanvas);
|
||||||
|
|
||||||
|
// prepare content
|
||||||
|
this.imgCanvas = imgCanvas;
|
||||||
|
this.maskCanvas = maskCanvas;
|
||||||
|
this.backupCanvas = backupCanvas;
|
||||||
|
this.maskCtx = maskCanvas.getContext('2d');
|
||||||
|
this.backupCtx = backupCanvas.getContext('2d');
|
||||||
|
|
||||||
|
this.setEventHandler(maskCanvas);
|
||||||
|
|
||||||
|
this.is_layout_created = true;
|
||||||
|
|
||||||
|
// replacement of onClose hook since close is not real close
|
||||||
|
const self = this;
|
||||||
|
const observer = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutation) {
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||||
|
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||||
|
ComfyApp.onClipspaceEditorClosed();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.last_display_style = self.element.style.display;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const config = { attributes: true };
|
||||||
|
observer.observe(this.element, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace_return_node) {
|
||||||
|
this.saveButton.innerText = "Save to node";
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
this.saveButton.innerText = "Save";
|
||||||
|
}
|
||||||
|
this.saveButton.disabled = false;
|
||||||
|
|
||||||
|
this.element.style.display = "block";
|
||||||
|
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||||
|
}
|
||||||
|
|
||||||
|
isOpened() {
|
||||||
|
return this.element.style.display == "block";
|
||||||
|
}
|
||||||
|
|
||||||
|
setImages(imgCanvas, backupCanvas) {
|
||||||
|
const imgCtx = imgCanvas.getContext('2d');
|
||||||
|
const backupCtx = backupCanvas.getContext('2d');
|
||||||
|
const maskCtx = this.maskCtx;
|
||||||
|
const maskCanvas = this.maskCanvas;
|
||||||
|
|
||||||
|
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||||
|
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||||
|
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||||
|
|
||||||
|
// image load
|
||||||
|
const orig_image = new Image();
|
||||||
|
window.addEventListener("resize", () => {
|
||||||
|
// repositioning
|
||||||
|
imgCanvas.width = window.innerWidth - 250;
|
||||||
|
imgCanvas.height = window.innerHeight - 200;
|
||||||
|
|
||||||
|
// redraw image
|
||||||
|
let drawWidth = orig_image.width;
|
||||||
|
let drawHeight = orig_image.height;
|
||||||
|
if (orig_image.width > imgCanvas.width) {
|
||||||
|
drawWidth = imgCanvas.width;
|
||||||
|
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (drawHeight > imgCanvas.height) {
|
||||||
|
drawHeight = imgCanvas.height;
|
||||||
|
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
||||||
|
}
|
||||||
|
|
||||||
|
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||||
|
|
||||||
|
// update mask
|
||||||
|
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||||
|
maskCanvas.width = drawWidth;
|
||||||
|
maskCanvas.height = drawHeight;
|
||||||
|
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||||
|
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||||
|
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||||
|
});
|
||||||
|
|
||||||
|
const filepath = ComfyApp.clipspace.images;
|
||||||
|
|
||||||
|
const touched_image = new Image();
|
||||||
|
|
||||||
|
touched_image.onload = function() {
|
||||||
|
backupCanvas.width = touched_image.width;
|
||||||
|
backupCanvas.height = touched_image.height;
|
||||||
|
|
||||||
|
prepareRGB(touched_image, backupCanvas, backupCtx);
|
||||||
|
};
|
||||||
|
|
||||||
|
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||||
|
alpha_url.searchParams.delete('channel');
|
||||||
|
alpha_url.searchParams.set('channel', 'a');
|
||||||
|
touched_image.src = alpha_url;
|
||||||
|
|
||||||
|
// original image load
|
||||||
|
orig_image.onload = function() {
|
||||||
|
window.dispatchEvent(new Event('resize'));
|
||||||
|
};
|
||||||
|
|
||||||
|
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||||
|
rgb_url.searchParams.delete('channel');
|
||||||
|
rgb_url.searchParams.set('channel', 'rgb');
|
||||||
|
orig_image.src = rgb_url;
|
||||||
|
this.image = orig_image;
|
||||||
|
}
|
||||||
|
|
||||||
|
setEventHandler(maskCanvas) {
|
||||||
|
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||||
|
event.preventDefault();
|
||||||
|
});
|
||||||
|
|
||||||
|
const self = this;
|
||||||
|
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||||
|
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||||
|
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||||
|
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||||
|
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||||
|
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||||
|
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||||
|
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||||
|
}
|
||||||
|
|
||||||
|
brush_size = 10;
|
||||||
|
drawing_mode = false;
|
||||||
|
lastx = -1;
|
||||||
|
lasty = -1;
|
||||||
|
lasttime = 0;
|
||||||
|
|
||||||
|
static handleKeyDown(event) {
|
||||||
|
const self = MaskEditorDialog.instance;
|
||||||
|
if (event.key === ']') {
|
||||||
|
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||||
|
} else if (event.key === '[') {
|
||||||
|
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||||
|
} else if(event.key === 'Enter') {
|
||||||
|
self.save();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.updateBrushPreview(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
static handlePointerUp(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
MaskEditorDialog.instance.drawing_mode = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateBrushPreview(self) {
|
||||||
|
const brush = self.brush;
|
||||||
|
|
||||||
|
var centerX = self.cursorX;
|
||||||
|
var centerY = self.cursorY;
|
||||||
|
|
||||||
|
brush.style.width = self.brush_size * 2 + "px";
|
||||||
|
brush.style.height = self.brush_size * 2 + "px";
|
||||||
|
brush.style.left = (centerX - self.brush_size) + "px";
|
||||||
|
brush.style.top = (centerY - self.brush_size) + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
handleWheelEvent(self, event) {
|
||||||
|
if(event.deltaY < 0)
|
||||||
|
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||||
|
else
|
||||||
|
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||||
|
|
||||||
|
self.brush_slider_input.value = self.brush_size;
|
||||||
|
|
||||||
|
self.updateBrushPreview(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
draw_move(self, event) {
|
||||||
|
event.preventDefault();
|
||||||
|
|
||||||
|
this.cursorX = event.pageX;
|
||||||
|
this.cursorY = event.pageY;
|
||||||
|
|
||||||
|
self.updateBrushPreview(self);
|
||||||
|
|
||||||
|
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
|
||||||
|
var diff = performance.now() - self.lasttime;
|
||||||
|
|
||||||
|
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||||
|
|
||||||
|
var x = event.offsetX;
|
||||||
|
var y = event.offsetY
|
||||||
|
|
||||||
|
if(event.offsetX == null) {
|
||||||
|
x = event.targetTouches[0].clientX - maskRect.left;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(event.offsetY == null) {
|
||||||
|
y = event.targetTouches[0].clientY - maskRect.top;
|
||||||
|
}
|
||||||
|
|
||||||
|
var brush_size = this.brush_size;
|
||||||
|
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||||
|
brush_size *= event.pressure;
|
||||||
|
this.last_pressure = event.pressure;
|
||||||
|
}
|
||||||
|
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||||
|
// The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents.
|
||||||
|
brush_size *= this.last_pressure;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
brush_size = this.brush_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(diff > 20 && !this.drawing_mode)
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
self.maskCtx.beginPath();
|
||||||
|
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||||
|
self.maskCtx.globalCompositeOperation = "source-over";
|
||||||
|
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||||
|
self.maskCtx.fill();
|
||||||
|
self.lastx = x;
|
||||||
|
self.lasty = y;
|
||||||
|
});
|
||||||
|
else
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
self.maskCtx.beginPath();
|
||||||
|
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||||
|
self.maskCtx.globalCompositeOperation = "source-over";
|
||||||
|
|
||||||
|
var dx = x - self.lastx;
|
||||||
|
var dy = y - self.lasty;
|
||||||
|
|
||||||
|
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||||
|
var directionX = dx / distance;
|
||||||
|
var directionY = dy / distance;
|
||||||
|
|
||||||
|
for (var i = 0; i < distance; i+=5) {
|
||||||
|
var px = self.lastx + (directionX * i);
|
||||||
|
var py = self.lasty + (directionY * i);
|
||||||
|
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||||
|
self.maskCtx.fill();
|
||||||
|
}
|
||||||
|
self.lastx = x;
|
||||||
|
self.lasty = y;
|
||||||
|
});
|
||||||
|
|
||||||
|
self.lasttime = performance.now();
|
||||||
|
}
|
||||||
|
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||||
|
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||||
|
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||||
|
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||||
|
|
||||||
|
var brush_size = this.brush_size;
|
||||||
|
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||||
|
brush_size *= event.pressure;
|
||||||
|
this.last_pressure = event.pressure;
|
||||||
|
}
|
||||||
|
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||||
|
brush_size *= this.last_pressure;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
brush_size = this.brush_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
self.maskCtx.beginPath();
|
||||||
|
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||||
|
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||||
|
self.maskCtx.fill();
|
||||||
|
self.lastx = x;
|
||||||
|
self.lasty = y;
|
||||||
|
});
|
||||||
|
else
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
self.maskCtx.beginPath();
|
||||||
|
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||||
|
|
||||||
|
var dx = x - self.lastx;
|
||||||
|
var dy = y - self.lasty;
|
||||||
|
|
||||||
|
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||||
|
var directionX = dx / distance;
|
||||||
|
var directionY = dy / distance;
|
||||||
|
|
||||||
|
for (var i = 0; i < distance; i+=5) {
|
||||||
|
var px = self.lastx + (directionX * i);
|
||||||
|
var py = self.lasty + (directionY * i);
|
||||||
|
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||||
|
self.maskCtx.fill();
|
||||||
|
}
|
||||||
|
self.lastx = x;
|
||||||
|
self.lasty = y;
|
||||||
|
});
|
||||||
|
|
||||||
|
self.lasttime = performance.now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handlePointerDown(self, event) {
|
||||||
|
var brush_size = this.brush_size;
|
||||||
|
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||||
|
brush_size *= event.pressure;
|
||||||
|
this.last_pressure = event.pressure;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ([0, 2, 5].includes(event.button)) {
|
||||||
|
self.drawing_mode = true;
|
||||||
|
|
||||||
|
event.preventDefault();
|
||||||
|
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||||
|
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||||
|
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||||
|
|
||||||
|
self.maskCtx.beginPath();
|
||||||
|
if (event.button == 0) {
|
||||||
|
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||||
|
self.maskCtx.globalCompositeOperation = "source-over";
|
||||||
|
} else {
|
||||||
|
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||||
|
}
|
||||||
|
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||||
|
self.maskCtx.fill();
|
||||||
|
self.lastx = x;
|
||||||
|
self.lasty = y;
|
||||||
|
self.lasttime = performance.now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async save() {
|
||||||
|
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||||
|
|
||||||
|
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||||
|
backupCtx.drawImage(this.maskCanvas,
|
||||||
|
0, 0, this.maskCanvas.width, this.maskCanvas.height,
|
||||||
|
0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||||
|
|
||||||
|
// paste mask data into alpha channel
|
||||||
|
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||||
|
|
||||||
|
// refine mask image
|
||||||
|
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||||
|
if(backupData.data[i+3] == 255)
|
||||||
|
backupData.data[i+3] = 0;
|
||||||
|
else
|
||||||
|
backupData.data[i+3] = 255;
|
||||||
|
|
||||||
|
backupData.data[i] = 0;
|
||||||
|
backupData.data[i+1] = 0;
|
||||||
|
backupData.data[i+2] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
backupCtx.globalCompositeOperation = 'source-over';
|
||||||
|
backupCtx.putImageData(backupData, 0, 0);
|
||||||
|
|
||||||
|
const formData = new FormData();
|
||||||
|
const filename = "clipspace-mask-" + performance.now() + ".png";
|
||||||
|
|
||||||
|
const item =
|
||||||
|
{
|
||||||
|
"filename": filename,
|
||||||
|
"subfolder": "clipspace",
|
||||||
|
"type": "input",
|
||||||
|
};
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace.images)
|
||||||
|
ComfyApp.clipspace.images[0] = item;
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace.widgets) {
|
||||||
|
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||||
|
|
||||||
|
if(index >= 0)
|
||||||
|
ComfyApp.clipspace.widgets[index].value = item;
|
||||||
|
}
|
||||||
|
|
||||||
|
const dataURL = this.backupCanvas.toDataURL();
|
||||||
|
const blob = dataURLToBlob(dataURL);
|
||||||
|
|
||||||
|
const original_blob = loadedImageToBlob(this.image);
|
||||||
|
|
||||||
|
formData.append('image', blob, filename);
|
||||||
|
formData.append('original_image', original_blob);
|
||||||
|
formData.append('type', "input");
|
||||||
|
formData.append('subfolder', "clipspace");
|
||||||
|
|
||||||
|
this.saveButton.innerText = "Saving...";
|
||||||
|
this.saveButton.disabled = true;
|
||||||
|
await uploadMask(item, formData);
|
||||||
|
ComfyApp.onClipspaceEditorSave();
|
||||||
|
this.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "Comfy.MaskEditor",
|
||||||
|
init(app) {
|
||||||
|
ComfyApp.open_maskeditor =
|
||||||
|
function () {
|
||||||
|
const dlg = MaskEditorDialog.getInstance();
|
||||||
|
if(!dlg.isOpened()) {
|
||||||
|
dlg.show();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
||||||
|
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
|
||||||
|
}
|
||||||
|
});
|
||||||
@ -300,7 +300,7 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (widget.type === "number") {
|
if (widget.type === "number" || widget.type === "combo") {
|
||||||
addValueControlWidget(this, widget, "fixed");
|
addValueControlWidget(this, widget, "fixed");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -14,5 +14,5 @@
|
|||||||
window.graph = app.graph;
|
window.graph = app.graph;
|
||||||
</script>
|
</script>
|
||||||
</head>
|
</head>
|
||||||
<body></body>
|
<body class="litegraph"></body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
|
|
||||||
//when clicked on top of a node
|
//when clicked on top of a node
|
||||||
//and it is not interactive
|
//and it is not interactive
|
||||||
if (node && this.allow_interaction && !skip_action && !this.read_only) {
|
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
|
||||||
if (!this.live_mode && !node.flags.pinned) {
|
if (!this.live_mode && !node.flags.pinned) {
|
||||||
this.bringToFront(node);
|
this.bringToFront(node);
|
||||||
} //if it wasn't selected?
|
} //if it wasn't selected?
|
||||||
|
|
||||||
//not dragging mouse to connect two slots
|
//not dragging mouse to connect two slots
|
||||||
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||||
//Search for corner for resize
|
//Search for corner for resize
|
||||||
if ( !skip_action &&
|
if ( !skip_action &&
|
||||||
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
||||||
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
}
|
}
|
||||||
|
|
||||||
//double clicking
|
//double clicking
|
||||||
if (is_double_click && this.selected_nodes[node.id]) {
|
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
|
||||||
//double click node
|
//double click node
|
||||||
if (node.onDblClick) {
|
if (node.onDblClick) {
|
||||||
node.onDblClick( e, pos, this );
|
node.onDblClick( e, pos, this );
|
||||||
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//get node over
|
||||||
|
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||||
|
|
||||||
if (this.dragging_rectangle)
|
if (this.dragging_rectangle)
|
||||||
{
|
{
|
||||||
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
||||||
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
this.ds.offset[1] += delta[1] / this.ds.scale;
|
this.ds.offset[1] += delta[1] / this.ds.scale;
|
||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
this.dirty_bgcanvas = true;
|
this.dirty_bgcanvas = true;
|
||||||
} else if (this.allow_interaction && !this.read_only) {
|
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
|
||||||
if (this.connecting_node) {
|
if (this.connecting_node) {
|
||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//get node over
|
|
||||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
|
||||||
|
|
||||||
//remove mouseover flag
|
//remove mouseover flag
|
||||||
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
||||||
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
||||||
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
if (show_text) {
|
if (show_text) {
|
||||||
ctx.textAlign = "center";
|
ctx.textAlign = "center";
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
|
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "toggle":
|
case "toggle":
|
||||||
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.fill();
|
ctx.fill();
|
||||||
if (show_text) {
|
if (show_text) {
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
if (w.name != null) {
|
const label = w.label || w.name;
|
||||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
if (label != null) {
|
||||||
|
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||||
}
|
}
|
||||||
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.textAlign = "center";
|
ctx.textAlign = "center";
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.fillText(
|
ctx.fillText(
|
||||||
w.name + " " + Number(w.value).toFixed(3),
|
w.label || w.name + " " + Number(w.value).toFixed(3),
|
||||||
widget_width * 0.5,
|
widget_width * 0.5,
|
||||||
y + H * 0.7
|
y + H * 0.7
|
||||||
);
|
);
|
||||||
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.fill();
|
ctx.fill();
|
||||||
}
|
}
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
|
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
if (w.type == "number") {
|
if (w.type == "number") {
|
||||||
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
|
|
||||||
//ctx.stroke();
|
//ctx.stroke();
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
if (w.name != null) {
|
const label = w.label || w.name;
|
||||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
if (label != null) {
|
||||||
|
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||||
}
|
}
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
event,
|
event,
|
||||||
active_widget
|
active_widget
|
||||||
) {
|
) {
|
||||||
if (!node.widgets || !node.widgets.length) {
|
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
canvas.graph.add(group);
|
canvas.graph.add(group);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determines the furthest nodes in each direction
|
||||||
|
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
|
||||||
|
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||||
|
*/
|
||||||
|
LGraphCanvas.getBoundaryNodes = function(nodes) {
|
||||||
|
let top = null;
|
||||||
|
let right = null;
|
||||||
|
let bottom = null;
|
||||||
|
let left = null;
|
||||||
|
for (const nID in nodes) {
|
||||||
|
const node = nodes[nID];
|
||||||
|
const [x, y] = node.pos;
|
||||||
|
const [width, height] = node.size;
|
||||||
|
|
||||||
|
if (top === null || y < top.pos[1]) {
|
||||||
|
top = node;
|
||||||
|
}
|
||||||
|
if (right === null || x + width > right.pos[0] + right.size[0]) {
|
||||||
|
right = node;
|
||||||
|
}
|
||||||
|
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
|
||||||
|
bottom = node;
|
||||||
|
}
|
||||||
|
if (left === null || x < left.pos[0]) {
|
||||||
|
left = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"top": top,
|
||||||
|
"right": right,
|
||||||
|
"bottom": bottom,
|
||||||
|
"left": left
|
||||||
|
};
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Determines the furthest nodes in each direction for the currently selected nodes
|
||||||
|
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||||
|
*/
|
||||||
|
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
|
||||||
|
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param {LGraphNode[]} nodes a list of nodes
|
||||||
|
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
|
||||||
|
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
|
||||||
|
*/
|
||||||
|
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
|
||||||
|
if (!nodes) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const canvas = LGraphCanvas.active_canvas;
|
||||||
|
let boundaryNodes = []
|
||||||
|
if (align_to === undefined) {
|
||||||
|
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
|
||||||
|
} else {
|
||||||
|
boundaryNodes = {
|
||||||
|
"top": align_to,
|
||||||
|
"right": align_to,
|
||||||
|
"bottom": align_to,
|
||||||
|
"left": align_to
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
|
||||||
|
switch (direction) {
|
||||||
|
case "right":
|
||||||
|
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
|
||||||
|
break;
|
||||||
|
case "left":
|
||||||
|
node.pos[0] = boundaryNodes["left"].pos[0];
|
||||||
|
break;
|
||||||
|
case "top":
|
||||||
|
node.pos[1] = boundaryNodes["top"].pos[1];
|
||||||
|
break;
|
||||||
|
case "bottom":
|
||||||
|
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
canvas.dirty_canvas = true;
|
||||||
|
canvas.dirty_bgcanvas = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
|
||||||
|
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||||
|
event: event,
|
||||||
|
callback: inner_clicked,
|
||||||
|
parentMenu: prev_menu,
|
||||||
|
});
|
||||||
|
|
||||||
|
function inner_clicked(value) {
|
||||||
|
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
|
||||||
|
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||||
|
event: event,
|
||||||
|
callback: inner_clicked,
|
||||||
|
parentMenu: prev_menu,
|
||||||
|
});
|
||||||
|
|
||||||
|
function inner_clicked(value) {
|
||||||
|
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
||||||
|
|
||||||
var canvas = LGraphCanvas.active_canvas;
|
var canvas = LGraphCanvas.active_canvas;
|
||||||
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
|
if (Object.keys(this.selected_nodes).length > 1) {
|
||||||
|
options.push({
|
||||||
|
content: "Align",
|
||||||
|
has_submenu: true,
|
||||||
|
callback: LGraphCanvas.onGroupAlign,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if (this._graph_stack && this._graph_stack.length > 0) {
|
if (this._graph_stack && this._graph_stack.length > 0) {
|
||||||
options.push(null, {
|
options.push(null, {
|
||||||
content: "Close subgraph",
|
content: "Close subgraph",
|
||||||
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
callback: LGraphCanvas.onMenuNodeToSubgraph
|
callback: LGraphCanvas.onMenuNodeToSubgraph
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (Object.keys(this.selected_nodes).length > 1) {
|
||||||
|
options.push({
|
||||||
|
content: "Align Selected To",
|
||||||
|
has_submenu: true,
|
||||||
|
callback: LGraphCanvas.onNodeAlign,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
options.push(null, {
|
options.push(null, {
|
||||||
content: "Remove",
|
content: "Remove",
|
||||||
disabled: !(node.removable !== false && !node.block_delete ),
|
disabled: !(node.removable !== false && !node.block_delete ),
|
||||||
|
|||||||
@ -88,6 +88,12 @@ class ComfyApi extends EventTarget {
|
|||||||
case "executed":
|
case "executed":
|
||||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||||
break;
|
break;
|
||||||
|
case "execution_start":
|
||||||
|
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||||
|
break;
|
||||||
|
case "execution_error":
|
||||||
|
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
if (this.#registered.has(msg.type)) {
|
if (this.#registered.has(msg.type)) {
|
||||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||||
@ -163,7 +169,7 @@ class ComfyApi extends EventTarget {
|
|||||||
|
|
||||||
if (res.status !== 200) {
|
if (res.status !== 200) {
|
||||||
throw {
|
throw {
|
||||||
response: await res.text(),
|
response: await res.json(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
|
|||||||
import { ComfyUI, $el } from "./ui.js";
|
import { ComfyUI, $el } from "./ui.js";
|
||||||
import { api } from "./api.js";
|
import { api } from "./api.js";
|
||||||
import { defaultGraph } from "./defaultGraph.js";
|
import { defaultGraph } from "./defaultGraph.js";
|
||||||
import { getPngMetadata, importA1111 } from "./pnginfo.js";
|
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||||
@ -25,6 +25,9 @@ export class ComfyApp {
|
|||||||
* @type {serialized node object}
|
* @type {serialized node object}
|
||||||
*/
|
*/
|
||||||
static clipspace = null;
|
static clipspace = null;
|
||||||
|
static clipspace_invalidate_handler = null;
|
||||||
|
static open_maskeditor = null;
|
||||||
|
static clipspace_return_node = null;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.ui = new ComfyUI(this);
|
this.ui = new ComfyUI(this);
|
||||||
@ -48,6 +51,114 @@ export class ComfyApp {
|
|||||||
this.shiftDown = false;
|
this.shiftDown = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static isImageNode(node) {
|
||||||
|
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static onClipspaceEditorSave() {
|
||||||
|
if(ComfyApp.clipspace_return_node) {
|
||||||
|
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static onClipspaceEditorClosed() {
|
||||||
|
ComfyApp.clipspace_return_node = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
static copyToClipspace(node) {
|
||||||
|
var widgets = null;
|
||||||
|
if(node.widgets) {
|
||||||
|
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||||
|
}
|
||||||
|
|
||||||
|
var imgs = undefined;
|
||||||
|
var orig_imgs = undefined;
|
||||||
|
if(node.imgs != undefined) {
|
||||||
|
imgs = [];
|
||||||
|
orig_imgs = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < node.imgs.length; i++) {
|
||||||
|
imgs[i] = new Image();
|
||||||
|
imgs[i].src = node.imgs[i].src;
|
||||||
|
orig_imgs[i] = imgs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var selectedIndex = 0;
|
||||||
|
if(node.imageIndex) {
|
||||||
|
selectedIndex = node.imageIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
ComfyApp.clipspace = {
|
||||||
|
'widgets': widgets,
|
||||||
|
'imgs': imgs,
|
||||||
|
'original_imgs': orig_imgs,
|
||||||
|
'images': node.images,
|
||||||
|
'selectedIndex': selectedIndex,
|
||||||
|
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
||||||
|
};
|
||||||
|
|
||||||
|
ComfyApp.clipspace_return_node = null;
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace_invalidate_handler) {
|
||||||
|
ComfyApp.clipspace_invalidate_handler();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static pasteFromClipspace(node) {
|
||||||
|
if(ComfyApp.clipspace) {
|
||||||
|
// image paste
|
||||||
|
if(ComfyApp.clipspace.imgs && node.imgs) {
|
||||||
|
if(node.images && ComfyApp.clipspace.images) {
|
||||||
|
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||||
|
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace.imgs) {
|
||||||
|
// deep-copy to cut link with clipspace
|
||||||
|
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||||
|
const img = new Image();
|
||||||
|
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||||
|
node.imgs = [img];
|
||||||
|
node.imageIndex = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const imgs = [];
|
||||||
|
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
||||||
|
imgs[i] = new Image();
|
||||||
|
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
||||||
|
node.imgs = imgs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.widgets) {
|
||||||
|
if(ComfyApp.clipspace.images) {
|
||||||
|
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||||
|
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||||
|
if(index >= 0) {
|
||||||
|
node.widgets[index].value = clip_image;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(ComfyApp.clipspace.widgets) {
|
||||||
|
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||||
|
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||||
|
if (prop && prop.type != 'button') {
|
||||||
|
prop.value = value;
|
||||||
|
prop.callback(value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.graph.setDirtyCanvas(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoke an extension callback
|
* Invoke an extension callback
|
||||||
* @param {keyof ComfyExtension} method The extension callback to execute
|
* @param {keyof ComfyExtension} method The extension callback to execute
|
||||||
@ -137,81 +248,30 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options.push(
|
// prevent conflict of clipspace content
|
||||||
{
|
if(!ComfyApp.clipspace_return_node) {
|
||||||
content: "Copy (Clipspace)",
|
options.push({
|
||||||
callback: (obj) => {
|
content: "Copy (Clipspace)",
|
||||||
var widgets = null;
|
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||||
if(this.widgets) {
|
});
|
||||||
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
|
||||||
}
|
|
||||||
|
|
||||||
let img = new Image();
|
|
||||||
var imgs = undefined;
|
|
||||||
if(this.imgs != undefined) {
|
|
||||||
img.src = this.imgs[0].src;
|
|
||||||
imgs = [img];
|
|
||||||
}
|
|
||||||
|
|
||||||
ComfyApp.clipspace = {
|
if(ComfyApp.clipspace != null) {
|
||||||
'widgets': widgets,
|
options.push({
|
||||||
'imgs': imgs,
|
content: "Paste (Clipspace)",
|
||||||
'original_imgs': imgs,
|
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||||
'images': this.images
|
});
|
||||||
};
|
}
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if(ComfyApp.clipspace != null) {
|
if(ComfyApp.isImageNode(this)) {
|
||||||
options.push(
|
options.push({
|
||||||
{
|
content: "Open in MaskEditor",
|
||||||
content: "Paste (Clipspace)",
|
callback: (obj) => {
|
||||||
callback: () => {
|
ComfyApp.copyToClipspace(this);
|
||||||
if(ComfyApp.clipspace != null) {
|
ComfyApp.clipspace_return_node = this;
|
||||||
if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
|
ComfyApp.open_maskeditor();
|
||||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
|
||||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
|
||||||
if (prop) {
|
|
||||||
prop.callback(value);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// image paste
|
|
||||||
if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
|
|
||||||
var filename = "";
|
|
||||||
if(this.images && ComfyApp.clipspace.images) {
|
|
||||||
this.images = ComfyApp.clipspace.images;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(ComfyApp.clipspace.images != undefined) {
|
|
||||||
const clip_image = ComfyApp.clipspace.images[0];
|
|
||||||
if(clip_image.subfolder != '')
|
|
||||||
filename = `${clip_image.subfolder}/`;
|
|
||||||
filename += `${clip_image.filename} [${clip_image.type}]`;
|
|
||||||
}
|
|
||||||
else if(ComfyApp.clipspace.widgets != undefined) {
|
|
||||||
const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
|
||||||
if(index_in_clip >= 0) {
|
|
||||||
filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const index = this.widgets.findIndex(obj => obj.name === 'image');
|
|
||||||
if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
|
|
||||||
this.imgs = ComfyApp.clipspace.imgs;
|
|
||||||
|
|
||||||
this.widgets[index].value = filename;
|
|
||||||
if(this.widgets_values != undefined) {
|
|
||||||
this.widgets_values[index] = filename;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.trigger('changed');
|
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -711,16 +771,27 @@ export class ComfyApp {
|
|||||||
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
|
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
|
||||||
const res = origDrawNodeShape.apply(this, arguments);
|
const res = origDrawNodeShape.apply(this, arguments);
|
||||||
|
|
||||||
|
const nodeErrors = self.lastPromptError?.node_errors[node.id];
|
||||||
|
|
||||||
let color = null;
|
let color = null;
|
||||||
|
let lineWidth = 1;
|
||||||
if (node.id === +self.runningNodeId) {
|
if (node.id === +self.runningNodeId) {
|
||||||
color = "#0f0";
|
color = "#0f0";
|
||||||
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
|
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
|
||||||
color = "dodgerblue";
|
color = "dodgerblue";
|
||||||
}
|
}
|
||||||
|
else if (self.lastPromptError != null && nodeErrors?.errors) {
|
||||||
|
color = "red";
|
||||||
|
lineWidth = 2;
|
||||||
|
}
|
||||||
|
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
|
||||||
|
color = "#f0f";
|
||||||
|
lineWidth = 2;
|
||||||
|
}
|
||||||
|
|
||||||
if (color) {
|
if (color) {
|
||||||
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
|
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
|
||||||
ctx.lineWidth = 1;
|
ctx.lineWidth = lineWidth;
|
||||||
ctx.globalAlpha = 0.8;
|
ctx.globalAlpha = 0.8;
|
||||||
ctx.beginPath();
|
ctx.beginPath();
|
||||||
if (shape == LiteGraph.BOX_SHAPE)
|
if (shape == LiteGraph.BOX_SHAPE)
|
||||||
@ -747,11 +818,28 @@ export class ComfyApp {
|
|||||||
ctx.stroke();
|
ctx.stroke();
|
||||||
ctx.strokeStyle = fgcolor;
|
ctx.strokeStyle = fgcolor;
|
||||||
ctx.globalAlpha = 1;
|
ctx.globalAlpha = 1;
|
||||||
|
}
|
||||||
|
|
||||||
if (self.progress) {
|
if (self.progress && node.id === +self.runningNodeId) {
|
||||||
ctx.fillStyle = "green";
|
ctx.fillStyle = "green";
|
||||||
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
|
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
|
||||||
ctx.fillStyle = bgcolor;
|
ctx.fillStyle = bgcolor;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Highlight inputs that failed validation
|
||||||
|
if (nodeErrors) {
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.strokeStyle = "red";
|
||||||
|
for (const error of nodeErrors.errors) {
|
||||||
|
if (error.extra_info && error.extra_info.input_name) {
|
||||||
|
const inputIndex = node.findInputSlot(error.extra_info.input_name)
|
||||||
|
if (inputIndex !== -1) {
|
||||||
|
let pos = node.getConnectionPos(true, inputIndex);
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
|
||||||
|
ctx.stroke();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -809,6 +897,17 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
api.addEventListener("execution_start", ({ detail }) => {
|
||||||
|
this.lastExecutionError = null
|
||||||
|
});
|
||||||
|
|
||||||
|
api.addEventListener("execution_error", ({ detail }) => {
|
||||||
|
this.lastExecutionError = detail;
|
||||||
|
const formattedError = this.#formatExecutionError(detail);
|
||||||
|
this.ui.dialog.show(formattedError);
|
||||||
|
this.canvas.draw(true, true);
|
||||||
|
});
|
||||||
|
|
||||||
api.init();
|
api.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -842,7 +941,9 @@ export class ComfyApp {
|
|||||||
await this.#loadExtensions();
|
await this.#loadExtensions();
|
||||||
|
|
||||||
// Create and mount the LiteGraph in the DOM
|
// Create and mount the LiteGraph in the DOM
|
||||||
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
|
const mainCanvas = document.createElement("canvas")
|
||||||
|
mainCanvas.style.touchAction = "none"
|
||||||
|
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
|
||||||
canvasEl.tabIndex = "1";
|
canvasEl.tabIndex = "1";
|
||||||
document.body.prepend(canvasEl);
|
document.body.prepend(canvasEl);
|
||||||
|
|
||||||
@ -909,6 +1010,11 @@ export class ComfyApp {
|
|||||||
const app = this;
|
const app = this;
|
||||||
// Load node definitions from the backend
|
// Load node definitions from the backend
|
||||||
const defs = await api.getNodeDefs();
|
const defs = await api.getNodeDefs();
|
||||||
|
await this.registerNodesFromDefs(defs);
|
||||||
|
await this.#invokeExtensionsAsync("registerCustomNodes");
|
||||||
|
}
|
||||||
|
|
||||||
|
async registerNodesFromDefs(defs) {
|
||||||
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
|
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
|
||||||
|
|
||||||
// Generate list of known widgets
|
// Generate list of known widgets
|
||||||
@ -954,7 +1060,8 @@ export class ComfyApp {
|
|||||||
for (const o in nodeData["output"]) {
|
for (const o in nodeData["output"]) {
|
||||||
const output = nodeData["output"][o];
|
const output = nodeData["output"][o];
|
||||||
const outputName = nodeData["output_name"][o] || output;
|
const outputName = nodeData["output_name"][o] || output;
|
||||||
this.addOutput(outputName, output);
|
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||||
|
this.addOutput(outputName, output, { shape: outputShape });
|
||||||
}
|
}
|
||||||
|
|
||||||
const s = this.computeSize();
|
const s = this.computeSize();
|
||||||
@ -980,8 +1087,6 @@ export class ComfyApp {
|
|||||||
LiteGraph.registerNodeType(nodeId, node);
|
LiteGraph.registerNodeType(nodeId, node);
|
||||||
node.category = nodeData.category;
|
node.category = nodeData.category;
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.#invokeExtensionsAsync("registerCustomNodes");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1180,6 +1285,43 @@ export class ComfyApp {
|
|||||||
return { workflow, output };
|
return { workflow, output };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#formatPromptError(error) {
|
||||||
|
if (error == null) {
|
||||||
|
return "(unknown error)"
|
||||||
|
}
|
||||||
|
else if (typeof error === "string") {
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
else if (error.stack && error.message) {
|
||||||
|
return error.toString()
|
||||||
|
}
|
||||||
|
else if (error.response) {
|
||||||
|
let message = error.response.error.message;
|
||||||
|
if (error.response.error.details)
|
||||||
|
message += ": " + error.response.error.details;
|
||||||
|
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
|
||||||
|
message += "\n" + nodeError.class_type + ":"
|
||||||
|
for (const errorReason of nodeError.errors) {
|
||||||
|
message += "\n - " + errorReason.message + ": " + errorReason.details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
return "(unknown error)"
|
||||||
|
}
|
||||||
|
|
||||||
|
#formatExecutionError(error) {
|
||||||
|
if (error == null) {
|
||||||
|
return "(unknown error)"
|
||||||
|
}
|
||||||
|
|
||||||
|
const traceback = error.traceback.join("")
|
||||||
|
const nodeId = error.node_id
|
||||||
|
const nodeType = error.node_type
|
||||||
|
|
||||||
|
return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
|
||||||
|
}
|
||||||
|
|
||||||
async queuePrompt(number, batchCount = 1) {
|
async queuePrompt(number, batchCount = 1) {
|
||||||
this.#queueItems.push({ number, batchCount });
|
this.#queueItems.push({ number, batchCount });
|
||||||
|
|
||||||
@ -1187,8 +1329,10 @@ export class ComfyApp {
|
|||||||
if (this.#processingQueue) {
|
if (this.#processingQueue) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.#processingQueue = true;
|
this.#processingQueue = true;
|
||||||
|
this.lastPromptError = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while (this.#queueItems.length) {
|
while (this.#queueItems.length) {
|
||||||
({ number, batchCount } = this.#queueItems.pop());
|
({ number, batchCount } = this.#queueItems.pop());
|
||||||
@ -1199,7 +1343,12 @@ export class ComfyApp {
|
|||||||
try {
|
try {
|
||||||
await api.queuePrompt(number, p);
|
await api.queuePrompt(number, p);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.ui.dialog.show(error.response || error.toString());
|
const formattedError = this.#formatPromptError(error)
|
||||||
|
this.ui.dialog.show(formattedError);
|
||||||
|
if (error.response) {
|
||||||
|
this.lastPromptError = error.response;
|
||||||
|
this.canvas.draw(true, true);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1245,6 +1394,11 @@ export class ComfyApp {
|
|||||||
this.loadGraphData(JSON.parse(reader.result));
|
this.loadGraphData(JSON.parse(reader.result));
|
||||||
};
|
};
|
||||||
reader.readAsText(file);
|
reader.readAsText(file);
|
||||||
|
} else if (file.name?.endsWith(".latent")) {
|
||||||
|
const info = await getLatentMetadata(file);
|
||||||
|
if (info.workflow) {
|
||||||
|
this.loadGraphData(JSON.parse(info.workflow));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1273,14 +1427,19 @@ export class ComfyApp {
|
|||||||
|
|
||||||
const def = defs[node.type];
|
const def = defs[node.type];
|
||||||
|
|
||||||
|
// HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
|
||||||
|
// and additional work is needed to consider the primitive logic in the refresh logic.
|
||||||
|
if(!def)
|
||||||
|
continue;
|
||||||
|
|
||||||
for(const widgetNum in node.widgets) {
|
for(const widgetNum in node.widgets) {
|
||||||
const widget = node.widgets[widgetNum]
|
const widget = node.widgets[widgetNum]
|
||||||
|
|
||||||
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
|
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
|
||||||
widget.options.values = def["input"]["required"][widget.name][0];
|
widget.options.values = def["input"]["required"][widget.name][0];
|
||||||
|
|
||||||
if(!widget.options.values.includes(widget.value)) {
|
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
|
||||||
widget.value = widget.options.values[0];
|
widget.value = widget.options.values[0];
|
||||||
|
widget.callback(widget.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1292,6 +1451,8 @@ export class ComfyApp {
|
|||||||
*/
|
*/
|
||||||
clean() {
|
clean() {
|
||||||
this.nodeOutputs = {};
|
this.nodeOutputs = {};
|
||||||
|
this.lastPromptError = null;
|
||||||
|
this.lastExecutionError = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -47,12 +47,29 @@ export function getPngMetadata(file) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getLatentMetadata(file) {
|
||||||
|
return new Promise((r) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (event) => {
|
||||||
|
const safetensorsData = new Uint8Array(event.target.result);
|
||||||
|
const dataView = new DataView(safetensorsData.buffer);
|
||||||
|
let header_size = dataView.getUint32(0, true);
|
||||||
|
let offset = 8;
|
||||||
|
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||||
|
r(header.__metadata__);
|
||||||
|
};
|
||||||
|
|
||||||
|
reader.readAsArrayBuffer(file);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
export async function importA1111(graph, parameters) {
|
export async function importA1111(graph, parameters) {
|
||||||
const p = parameters.lastIndexOf("\nSteps:");
|
const p = parameters.lastIndexOf("\nSteps:");
|
||||||
if (p > -1) {
|
if (p > -1) {
|
||||||
const embeddings = await api.getEmbeddings();
|
const embeddings = await api.getEmbeddings();
|
||||||
const opts = parameters
|
const opts = parameters
|
||||||
.substr(p)
|
.substr(p)
|
||||||
|
.split("\n")[1]
|
||||||
.split(",")
|
.split(",")
|
||||||
.reduce((p, n) => {
|
.reduce((p, n) => {
|
||||||
const s = n.split(":");
|
const s = n.split(":");
|
||||||
|
|||||||
@ -465,7 +465,7 @@ export class ComfyUI {
|
|||||||
const fileInput = $el("input", {
|
const fileInput = $el("input", {
|
||||||
id: "comfy-file-input",
|
id: "comfy-file-input",
|
||||||
type: "file",
|
type: "file",
|
||||||
accept: ".json,image/png",
|
accept: ".json,image/png,.latent",
|
||||||
style: { display: "none" },
|
style: { display: "none" },
|
||||||
parent: document.body,
|
parent: document.body,
|
||||||
onchange: () => {
|
onchange: () => {
|
||||||
@ -581,6 +581,7 @@ export class ComfyUI {
|
|||||||
}),
|
}),
|
||||||
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
||||||
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
||||||
|
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
|
||||||
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
||||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||||
app.clean();
|
app.clean();
|
||||||
|
|||||||
@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
|||||||
|
|
||||||
var v = valueControl.value;
|
var v = valueControl.value;
|
||||||
|
|
||||||
let min = targetWidget.options.min;
|
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||||
let max = targetWidget.options.max;
|
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||||
// limit to something that javascript can handle
|
let current_length = targetWidget.options.values.length;
|
||||||
max = Math.min(1125899906842624, max);
|
|
||||||
min = Math.max(-1125899906842624, min);
|
|
||||||
let range = (max - min) / (targetWidget.options.step / 10);
|
|
||||||
|
|
||||||
//adjust values based on valueControl Behaviour
|
switch (v) {
|
||||||
switch (v) {
|
case "increment":
|
||||||
case "fixed":
|
current_index += 1;
|
||||||
break;
|
break;
|
||||||
case "increment":
|
case "decrement":
|
||||||
targetWidget.value += targetWidget.options.step / 10;
|
current_index -= 1;
|
||||||
break;
|
break;
|
||||||
case "decrement":
|
case "randomize":
|
||||||
targetWidget.value -= targetWidget.options.step / 10;
|
current_index = Math.floor(Math.random() * current_length);
|
||||||
break;
|
default:
|
||||||
case "randomize":
|
break;
|
||||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
}
|
||||||
default:
|
current_index = Math.max(0, current_index);
|
||||||
break;
|
current_index = Math.min(current_length - 1, current_index);
|
||||||
|
if (current_index >= 0) {
|
||||||
|
let value = targetWidget.options.values[current_index];
|
||||||
|
targetWidget.value = value;
|
||||||
|
targetWidget.callback(value);
|
||||||
|
}
|
||||||
|
} else { //number
|
||||||
|
let min = targetWidget.options.min;
|
||||||
|
let max = targetWidget.options.max;
|
||||||
|
// limit to something that javascript can handle
|
||||||
|
max = Math.min(1125899906842624, max);
|
||||||
|
min = Math.max(-1125899906842624, min);
|
||||||
|
let range = (max - min) / (targetWidget.options.step / 10);
|
||||||
|
|
||||||
|
//adjust values based on valueControl Behaviour
|
||||||
|
switch (v) {
|
||||||
|
case "fixed":
|
||||||
|
break;
|
||||||
|
case "increment":
|
||||||
|
targetWidget.value += targetWidget.options.step / 10;
|
||||||
|
break;
|
||||||
|
case "decrement":
|
||||||
|
targetWidget.value -= targetWidget.options.step / 10;
|
||||||
|
break;
|
||||||
|
case "randomize":
|
||||||
|
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
/*check if values are over or under their respective
|
||||||
|
* ranges and set them to min or max.*/
|
||||||
|
if (targetWidget.value < min)
|
||||||
|
targetWidget.value = min;
|
||||||
|
|
||||||
|
if (targetWidget.value > max)
|
||||||
|
targetWidget.value = max;
|
||||||
}
|
}
|
||||||
/*check if values are over or under their respective
|
|
||||||
* ranges and set them to min or max.*/
|
|
||||||
if (targetWidget.value < min)
|
|
||||||
targetWidget.value = min;
|
|
||||||
|
|
||||||
if (targetWidget.value > max)
|
|
||||||
targetWidget.value = max;
|
|
||||||
}
|
}
|
||||||
return valueControl;
|
return valueControl;
|
||||||
};
|
};
|
||||||
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
|
|||||||
computeSize(node.size);
|
computeSize(node.size);
|
||||||
}
|
}
|
||||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||||
const t = ctx.getTransform();
|
|
||||||
const margin = 10;
|
const margin = 10;
|
||||||
|
const elRect = ctx.canvas.getBoundingClientRect();
|
||||||
|
const transform = new DOMMatrix()
|
||||||
|
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||||
|
.multiplySelf(ctx.getTransform())
|
||||||
|
.translateSelf(margin, margin + y);
|
||||||
|
|
||||||
Object.assign(this.inputEl.style, {
|
Object.assign(this.inputEl.style, {
|
||||||
left: `${t.a * margin + t.e}px`,
|
transformOrigin: "0 0",
|
||||||
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
transform: transform,
|
||||||
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
left: "0px",
|
||||||
background: (!node.color)?'':node.color,
|
top: "0px",
|
||||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
width: `${widgetWidth - (margin * 2)}px`,
|
||||||
|
height: `${this.parent.inputHeight - (margin * 2)}px`,
|
||||||
position: "absolute",
|
position: "absolute",
|
||||||
|
background: (!node.color)?'':node.color,
|
||||||
color: (!node.color)?'':'white',
|
color: (!node.color)?'':'white',
|
||||||
zIndex: app.graph._nodes.indexOf(node),
|
zIndex: app.graph._nodes.indexOf(node),
|
||||||
fontSize: `${t.d * 10.0}px`,
|
|
||||||
});
|
});
|
||||||
this.inputEl.hidden = !visible;
|
this.inputEl.hidden = !visible;
|
||||||
},
|
},
|
||||||
@ -266,10 +297,46 @@ export const ComfyWidgets = {
|
|||||||
node.imgs = [img];
|
node.imgs = [img];
|
||||||
app.graph.setDirtyCanvas(true);
|
app.graph.setDirtyCanvas(true);
|
||||||
};
|
};
|
||||||
img.src = `/view?filename=${name}&type=input`;
|
let folder_separator = name.lastIndexOf("/");
|
||||||
|
let subfolder = "";
|
||||||
|
if (folder_separator > -1) {
|
||||||
|
subfolder = name.substring(0, folder_separator);
|
||||||
|
name = name.substring(folder_separator + 1);
|
||||||
|
}
|
||||||
|
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
|
||||||
node.setSizeForImage?.();
|
node.setSizeForImage?.();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var default_value = imageWidget.value;
|
||||||
|
Object.defineProperty(imageWidget, "value", {
|
||||||
|
set : function(value) {
|
||||||
|
this._real_value = value;
|
||||||
|
},
|
||||||
|
|
||||||
|
get : function() {
|
||||||
|
let value = "";
|
||||||
|
if (this._real_value) {
|
||||||
|
value = this._real_value;
|
||||||
|
} else {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (value.filename) {
|
||||||
|
let real_value = value;
|
||||||
|
value = "";
|
||||||
|
if (real_value.subfolder) {
|
||||||
|
value = real_value.subfolder + "/";
|
||||||
|
}
|
||||||
|
|
||||||
|
value += real_value.filename;
|
||||||
|
|
||||||
|
if(real_value.type && real_value.type !== "input")
|
||||||
|
value += ` [${real_value.type}]`;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Add our own callback to the combo widget to render an image when it changes
|
// Add our own callback to the combo widget to render an image when it changes
|
||||||
const cb = node.callback;
|
const cb = node.callback;
|
||||||
imageWidget.callback = function () {
|
imageWidget.callback = function () {
|
||||||
|
|||||||
@ -39,6 +39,8 @@ body {
|
|||||||
padding: 2px;
|
padding: 2px;
|
||||||
resize: none;
|
resize: none;
|
||||||
border: none;
|
border: none;
|
||||||
|
box-sizing: border-box;
|
||||||
|
font-size: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.comfy-modal {
|
.comfy-modal {
|
||||||
@ -287,6 +289,11 @@ button.comfy-queue-btn {
|
|||||||
|
|
||||||
/* Context menu */
|
/* Context menu */
|
||||||
|
|
||||||
|
.litegraph .dialog {
|
||||||
|
z-index: 1;
|
||||||
|
font-family: Arial;
|
||||||
|
}
|
||||||
|
|
||||||
.litegraph .litemenu-entry.has_submenu {
|
.litegraph .litemenu-entry.has_submenu {
|
||||||
position: relative;
|
position: relative;
|
||||||
padding-right: 20px;
|
padding-right: 20px;
|
||||||
@ -329,6 +336,7 @@ button.comfy-queue-btn {
|
|||||||
z-index: 9999 !important;
|
z-index: 9999 !important;
|
||||||
background-color: var(--comfy-menu-bg) !important;
|
background-color: var(--comfy-menu-bg) !important;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
display: block;
|
||||||
}
|
}
|
||||||
|
|
||||||
.litegraph.litesearchbox input,
|
.litegraph.litesearchbox input,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user