mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 12:50:49 +08:00
Merge branch 'comfyanonymous:master' into feature/blockweights
This commit is contained in:
commit
23332731bd
@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
|
||||
else:
|
||||
raise AssertionError('Unknown merge analysis result')
|
||||
|
||||
|
||||
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
|
||||
repo = pygit2.Repository(str(sys.argv[1]))
|
||||
ident = pygit2.Signature('comfyui', 'comfy@ui')
|
||||
try:
|
||||
|
||||
@ -30,6 +30,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
|
||||
@ -17,6 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11.3'
|
||||
|
||||
@ -1,14 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
|
||||
import os.path as osp
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
|
||||
|
||||
@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
|
||||
return text_enc_dict
|
||||
|
||||
|
||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
||||
|
||||
# magic
|
||||
v2 = diffusers_unet_conf["sample_size"] == 96
|
||||
if 'prediction_type' in diffusers_scheduler_conf:
|
||||
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
||||
|
||||
if v2:
|
||||
if v_pred:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
||||
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
vae_config = model_config_params['first_stage_config']
|
||||
vae_config['scale_factor'] = scale_factor
|
||||
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
||||
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||
|
||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = load_file(unet_path, device="cpu")
|
||||
else:
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = load_file(vae_path, device="cpu")
|
||||
else:
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||
else:
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
||||
if is_v20_model:
|
||||
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||
else:
|
||||
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
|
||||
clip = None
|
||||
vae = None
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
w = WeightsLoader()
|
||||
load_state_dict_to = []
|
||||
if output_vae:
|
||||
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
||||
w.first_stage_model = vae.first_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
if output_clip:
|
||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
model = instantiate_from_config(config["model"])
|
||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
return ModelPatcher(model), clip, vae
|
||||
|
||||
111
comfy/diffusers_load.py
Normal file
111
comfy/diffusers_load.py
Normal file
@ -0,0 +1,111 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
|
||||
import os.path as osp
|
||||
import re
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import diffusers_convert
|
||||
|
||||
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
|
||||
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
|
||||
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
|
||||
|
||||
# magic
|
||||
v2 = diffusers_unet_conf["sample_size"] == 96
|
||||
if 'prediction_type' in diffusers_scheduler_conf:
|
||||
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
|
||||
|
||||
if v2:
|
||||
if v_pred:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
|
||||
else:
|
||||
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
|
||||
|
||||
with open(config_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
vae_config = model_config_params['first_stage_config']
|
||||
vae_config['scale_factor'] = scale_factor
|
||||
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
|
||||
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||
|
||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = load_file(unet_path, device="cpu")
|
||||
else:
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = load_file(vae_path, device="cpu")
|
||||
else:
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||
else:
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
||||
if is_v20_model:
|
||||
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
||||
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
||||
else:
|
||||
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
|
||||
clip = None
|
||||
vae = None
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
w = WeightsLoader()
|
||||
load_state_dict_to = []
|
||||
if output_vae:
|
||||
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
||||
w.first_stage_model = vae.first_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
if output_clip:
|
||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
|
||||
model = instantiate_from_config(config["model"])
|
||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
return ModelPatcher(model), clip, vae
|
||||
@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
return x+h
|
||||
|
||||
def slice_attention(q, k, v):
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
scale = (int(q.shape[-1])**(-0.5))
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
|
||||
return r1
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
scale = (int(c)**(-0.5))
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
del r1
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
"""
|
||||
B, N, _ = metric.shape
|
||||
|
||||
if r <= 0:
|
||||
if r <= 0 or w == 1 or h == 1:
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
@ -1,23 +1,29 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import torch
|
||||
|
||||
class VRAMState(Enum):
|
||||
CPU = 0
|
||||
DISABLED = 0
|
||||
NO_VRAM = 1
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
MPS = 5
|
||||
SHARED = 5
|
||||
|
||||
class CPUState(Enum):
|
||||
GPU = 0
|
||||
CPU = 1
|
||||
MPS = 2
|
||||
|
||||
# Determine VRAM State
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
total_vram_available_mb = -1
|
||||
|
||||
accelerate_enabled = False
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
directml_enabled = False
|
||||
@ -31,30 +37,80 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
print("Using directml with device:", torch_directml.device_name(device_index))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import torch
|
||||
if directml_enabled:
|
||||
total_vram = 4097 #TODO
|
||||
else:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
|
||||
except:
|
||||
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
if not args.normalvram and not args.cpu:
|
||||
if total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
cpu_state = CPUState.MPS
|
||||
except:
|
||||
pass
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
global cpu_state
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if cpu_state == CPUState.MPS:
|
||||
return torch.device("mps")
|
||||
if cpu_state == CPUState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
mem_total_torch = mem_total
|
||||
elif xpu_available:
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = mem_total_cuda
|
||||
|
||||
if torch_total_too:
|
||||
return (mem_total, mem_total_torch)
|
||||
else:
|
||||
return mem_total
|
||||
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
if not args.normalvram and not args.cpu:
|
||||
if lowvram_available and total_vram <= 4096:
|
||||
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
elif total_vram > total_ram * 1.1 and total_vram > 14336:
|
||||
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
|
||||
vram_state = VRAMState.HIGH_VRAM
|
||||
|
||||
try:
|
||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||
except:
|
||||
@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
if args.lowvram:
|
||||
set_vram_to = VRAMState.LOW_VRAM
|
||||
lowvram_available = True
|
||||
elif args.novram:
|
||||
set_vram_to = VRAMState.NO_VRAM
|
||||
elif args.highvram:
|
||||
@ -102,32 +159,42 @@ if args.force_fp32:
|
||||
print("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
accelerate_enabled = True
|
||||
vram_state = set_vram_to
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
|
||||
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
|
||||
lowvram_available = False
|
||||
|
||||
total_vram_available_mb = (total_vram - 1024) // 2
|
||||
total_vram_available_mb = int(max(256, total_vram_available_mb))
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
vram_state = VRAMState.MPS
|
||||
except:
|
||||
pass
|
||||
if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
if args.cpu:
|
||||
vram_state = VRAMState.CPU
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
if device.type == "cuda":
|
||||
return "{} {}".format(device, torch.cuda.get_device_name(device))
|
||||
else:
|
||||
return "{}".format(device.type)
|
||||
else:
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Device:", get_torch_device_name(get_torch_device()))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
current_gpu_controlnets = []
|
||||
|
||||
@ -173,22 +240,29 @@ def load_model_gpu(model):
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
model.model_patches_to(get_torch_device())
|
||||
torch_dev = get_torch_device()
|
||||
model.model_patches_to(torch_dev)
|
||||
|
||||
vram_set_state = vram_state
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = model.model_size()
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
current_loaded_model = model
|
||||
if vram_state == VRAMState.CPU:
|
||||
|
||||
if vram_set_state == VRAMState.DISABLED:
|
||||
pass
|
||||
elif vram_state == VRAMState.MPS:
|
||||
mps_device = torch.device("mps")
|
||||
real_model.to(mps_device)
|
||||
pass
|
||||
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
real_model.to(get_torch_device())
|
||||
else:
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
elif vram_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
|
||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
|
||||
model_accelerated = True
|
||||
@ -197,7 +271,7 @@ def load_model_gpu(model):
|
||||
def load_controlnet_gpu(control_models):
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
if vram_state == VRAMState.CPU:
|
||||
if vram_state == VRAMState.DISABLED:
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
@ -233,22 +307,6 @@ def unload_if_low_vram(model):
|
||||
return model.cpu()
|
||||
return model
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
return dev.type
|
||||
@ -258,7 +316,8 @@ def get_autocast_device(dev):
|
||||
def xformers_enabled():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if vram_state == VRAMState.CPU:
|
||||
global cpu_state
|
||||
if cpu_state != CPUState.GPU:
|
||||
return False
|
||||
if xpu_available:
|
||||
return False
|
||||
@ -330,12 +389,12 @@ def maximum_batch_area():
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.CPU
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
||||
def mps_mode():
|
||||
global vram_state
|
||||
return vram_state == VRAMState.MPS
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.MPS
|
||||
|
||||
def should_use_fp16():
|
||||
global xpu_available
|
||||
@ -367,7 +426,10 @@ def should_use_fp16():
|
||||
|
||||
def soft_empty_cache():
|
||||
global xpu_available
|
||||
if xpu_available:
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif xpu_available:
|
||||
torch.xpu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
|
||||
|
||||
@ -2,17 +2,26 @@ import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def prepare_noise(latent_image, seed, skip=0):
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
for _ in range(skip):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return noise
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
|
||||
@ -6,6 +6,10 @@ import contextlib
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||
return False
|
||||
s1 = c1['c_crossattn'].shape
|
||||
s2 = c2['c_crossattn'].shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c_crossattn.append(x['c_crossattn'])
|
||||
c = x['c_crossattn']
|
||||
if crossattn_max_len == 0:
|
||||
crossattn_max_len = c.shape[1]
|
||||
else:
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
c_crossattn.append(c)
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
if 'c_adm' in x:
|
||||
c_adm.append(x['c_adm'])
|
||||
out = {}
|
||||
if len(c_crossattn) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||
c_crossattn_out = []
|
||||
for c in c_crossattn:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
if len(c_adm) > 0:
|
||||
@ -362,19 +386,8 @@ def resolve_cond_masks(conditions, h, w, device):
|
||||
else:
|
||||
box = boxes[0]
|
||||
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
||||
# Make sure the height and width are divisible by 8
|
||||
if X % 8 != 0:
|
||||
newx = X // 8 * 8
|
||||
W = W + (X - newx)
|
||||
X = newx
|
||||
if Y % 8 != 0:
|
||||
newy = Y // 8 * 8
|
||||
H = H + (Y - newy)
|
||||
Y = newy
|
||||
if H % 8 != 0:
|
||||
H = H + (8 - (H % 8))
|
||||
if W % 8 != 0:
|
||||
W = W + (8 - (W % 8))
|
||||
H = max(8, H)
|
||||
W = max(8, W)
|
||||
area = (int(H), int(W), int(Y), int(X))
|
||||
modified['area'] = area
|
||||
|
||||
@ -482,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -519,6 +532,8 @@ class KSampler:
|
||||
|
||||
if self.scheduler == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "normal":
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
elif self.scheduler == "simple":
|
||||
|
||||
66
comfy/sd.py
66
comfy/sd.py
@ -14,6 +14,7 @@ from .t2i_adapter import adapter
|
||||
from . import utils
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
from . import diffusers_convert
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
@ -324,15 +325,29 @@ def model_lora_keys(model, key_map={}):
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, size=0):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = []
|
||||
self.backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
size = 0
|
||||
for k in model_sd:
|
||||
t = model_sd[k]
|
||||
size += t.nelement() * t.element_size()
|
||||
self.size = size
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model)
|
||||
n = ModelPatcher(self.model, self.size)
|
||||
n.patches = self.patches[:]
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
return n
|
||||
@ -553,10 +568,16 @@ class VAE:
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
|
||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
if ckpt_path is not None:
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
if device is None:
|
||||
device = model_management.get_torch_device()
|
||||
@ -630,12 +651,9 @@ class VAE:
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
||||
target_batch_size = target_latent_tensor.shape[0]
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
print(current_batch_size, target_batch_size)
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@ -652,7 +670,7 @@ def resize_image_to(tensor, target_latent_tensor, batched_number):
|
||||
return torch.cat([tensor] * batched_number, dim=0)
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model, device=None):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
@ -661,6 +679,7 @@ class ControlNet:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt, batched_number):
|
||||
control_prev = None
|
||||
@ -672,7 +691,9 @@ class ControlNet:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
if self.control_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
@ -694,6 +715,9 @@ class ControlNet:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control[i]
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
x = x.to(output_dtype)
|
||||
@ -724,7 +748,7 @@ class ControlNet:
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
return c
|
||||
@ -772,7 +796,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
use_spatial_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
use_checkpoint=False,
|
||||
legacy=False,
|
||||
use_fp16=use_fp16)
|
||||
else:
|
||||
@ -789,7 +813,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
use_linear_in_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
use_checkpoint=False,
|
||||
legacy=False,
|
||||
use_fp16=use_fp16)
|
||||
if pth:
|
||||
@ -819,7 +843,11 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
|
||||
control = ControlNet(control_model)
|
||||
global_average_pooling = False
|
||||
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
return control
|
||||
|
||||
class T2IAdapter:
|
||||
@ -843,10 +871,14 @@ class T2IAdapter:
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
self.t2i_model.to(self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint)
|
||||
self.t2i_model.cpu()
|
||||
@ -1070,7 +1102,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
}
|
||||
|
||||
unet_config = {
|
||||
"use_checkpoint": True,
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"out_channels": 4,
|
||||
"attention_resolutions": [
|
||||
|
||||
@ -56,7 +56,12 @@ class Downsample(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
if not self.use_conv:
|
||||
padding = [x.shape[2] % 2, x.shape[3] % 2]
|
||||
self.op.padding = padding
|
||||
|
||||
x = self.op(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
@ -46,6 +51,88 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return sd
|
||||
|
||||
def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
with open(safetensors_path, "rb") as f:
|
||||
header = f.read(8)
|
||||
length_of_header = struct.unpack('<Q', header)[0]
|
||||
if length_of_header > max_size:
|
||||
return None
|
||||
return f.read(length_of_header)
|
||||
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
|
||||
c = b1.shape[-1]
|
||||
|
||||
#norms
|
||||
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||
|
||||
#normalize
|
||||
b1_normalized = b1 / b1_norms
|
||||
b2_normalized = b2 / b2_norms
|
||||
|
||||
#zero when norms are zero
|
||||
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
||||
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
||||
|
||||
#slerp
|
||||
dot = (b1_normalized*b2_normalized).sum(1)
|
||||
omega = torch.acos(dot)
|
||||
so = torch.sin(omega)
|
||||
|
||||
#technically not mathematically correct, but more pleasing?
|
||||
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||
|
||||
#edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
def generate_bilinear_data(length_old, length_new):
|
||||
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
coords_2 = coords_2.to(torch.int64)
|
||||
return ratios, coords_1, coords_2
|
||||
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
coords_2 = coords_2.expand((n, c, h, -1))
|
||||
ratios = ratios.expand((n, 1, h, -1))
|
||||
|
||||
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
|
||||
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||
|
||||
#linear h
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
|
||||
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
||||
|
||||
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
|
||||
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
|
||||
ratios = ratios.movedim(1, -1).reshape((-1,1))
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||
return result
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
@ -61,7 +148,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
return bislerp(s, width, height)
|
||||
else:
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CA_layer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(CA_layer, self).__init__()
|
||||
# global average pooling
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
|
||||
# nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc(self.gap(x))
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class Simple_CA_layer(nn.Module):
|
||||
def __init__(self, channel):
|
||||
super(Simple_CA_layer, self).__init__()
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=channel,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.fc(self.gap(x))
|
||||
|
||||
|
||||
class ECA_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.avg_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class ECA_MaxPool_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_MaxPool_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.max_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
||||
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
@ -0,0 +1,577 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSA.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from torch import einsum, nn
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def cast_tuple(val, length=1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
|
||||
# helper classes
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class Conv_PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Gated_Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = int(dim * mult)
|
||||
|
||||
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
||||
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=hidden_features * 2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.project_in(x)
|
||||
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||
x = F.gelu(x1) * x2
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
|
||||
# MBConv
|
||||
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate=0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce("b c h w -> b c", "mean"),
|
||||
nn.Linear(dim, hidden_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias=False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange("b c -> b c 1 1"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob=0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0.0 or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = (
|
||||
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
||||
> self.prob
|
||||
)
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
|
||||
def MBConv(
|
||||
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(
|
||||
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
||||
),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
# nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout=dropout)
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# attention related classes
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
if self.with_pe:
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = rearrange(grid, "c i j -> (i j) c")
|
||||
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
||||
grid, "j ... -> 1 j ..."
|
||||
)
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = (
|
||||
*x.shape,
|
||||
x.device,
|
||||
self.heads,
|
||||
)
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# add positional bias
|
||||
if self.with_pe:
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, "i j h -> h i j")
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(
|
||||
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
||||
)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
||||
|
||||
|
||||
class Block_Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.ps = window_size
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
# project for queries, keys, values
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
||||
h=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# attention
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
out = rearrange(
|
||||
out,
|
||||
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
||||
x=h // self.ps,
|
||||
y=w // self.ps,
|
||||
head=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
)
|
||||
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention_grid(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention_grid, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class OSA_Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
ffn_bias=True,
|
||||
window_size=8,
|
||||
with_pe=False,
|
||||
dropout=0.0,
|
||||
):
|
||||
super(OSA_Block, self).__init__()
|
||||
|
||||
w = window_size
|
||||
|
||||
self.layer = nn.Sequential(
|
||||
MBConv(
|
||||
channel_num,
|
||||
channel_num,
|
||||
downsample=False,
|
||||
expansion_rate=1,
|
||||
shrinkage_rate=0.25,
|
||||
),
|
||||
Rearrange(
|
||||
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # block-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
Rearrange(
|
||||
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # grid-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention_grid(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return out
|
||||
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSAG.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .esa import ESA
|
||||
from .OSA import OSA_Block
|
||||
|
||||
|
||||
class OSAG(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
block_num=4,
|
||||
ffn_bias=False,
|
||||
window_size=0,
|
||||
pe=False,
|
||||
):
|
||||
super(OSAG, self).__init__()
|
||||
|
||||
# print("window_size: %d" % (window_size))
|
||||
# print("with_pe", pe)
|
||||
# print("ffn_bias: %d" % (ffn_bias))
|
||||
|
||||
# block_script_name = kwargs.get("block_script_name", "OSA")
|
||||
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
|
||||
|
||||
# script_name = "." + block_script_name
|
||||
# package = __import__(script_name, fromlist=True)
|
||||
block_class = OSA_Block # getattr(package, block_class_name)
|
||||
group_list = []
|
||||
for _ in range(block_num):
|
||||
temp_res = block_class(
|
||||
channel_num,
|
||||
bias,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=window_size,
|
||||
with_pe=pe,
|
||||
)
|
||||
group_list.append(temp_res)
|
||||
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
|
||||
self.residual_layer = nn.Sequential(*group_list)
|
||||
esa_channel = max(channel_num // 4, 16)
|
||||
self.esa = ESA(esa_channel, channel_num)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.residual_layer(x)
|
||||
out = out + x
|
||||
return self.esa(out)
|
||||
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OmniSR.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .OSAG import OSAG
|
||||
from .pixelshuffle import pixelshuffle_block
|
||||
|
||||
|
||||
class OmniSR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
):
|
||||
super(OmniSR, self).__init__()
|
||||
self.state = state_dict
|
||||
|
||||
bias = True # Fine to assume this for now
|
||||
block_num = 1 # Fine to assume this for now
|
||||
ffn_bias = True
|
||||
pe = True
|
||||
|
||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||
|
||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||
if up_scale - int(up_scale) > 0:
|
||||
print(
|
||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||
)
|
||||
up_scale = int(up_scale)
|
||||
res_num = 0
|
||||
for key in state_dict.keys():
|
||||
if "residual_layer" in key:
|
||||
temp_res_num = int(key.split(".")[1])
|
||||
if temp_res_num > res_num:
|
||||
res_num = temp_res_num
|
||||
res_num = res_num + 1 # zero-indexed
|
||||
|
||||
residual_layer = []
|
||||
self.res_num = res_num
|
||||
|
||||
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
|
||||
self.up_scale = up_scale
|
||||
|
||||
for _ in range(res_num):
|
||||
temp_res = OSAG(
|
||||
channel_num=num_feat,
|
||||
bias=bias,
|
||||
block_num=block_num,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=self.window_size,
|
||||
pe=pe,
|
||||
)
|
||||
residual_layer.append(temp_res)
|
||||
self.residual_layer = nn.Sequential(*residual_layer)
|
||||
self.input = nn.Conv2d(
|
||||
in_channels=num_in_ch,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.output = nn.Conv2d(
|
||||
in_channels=num_feat,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||
|
||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||
|
||||
# chaiNNer specific stuff
|
||||
self.model_arch = "OmniSR"
|
||||
self.sub_type = "SR"
|
||||
self.in_nc = num_in_ch
|
||||
self.out_nc = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.scale = up_scale
|
||||
|
||||
self.supports_fp16 = True # TODO: Test this
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = 16
|
||||
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
# import pdb; pdb.set_trace()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
residual = self.input(x)
|
||||
out = self.residual_layer(residual)
|
||||
|
||||
# origin
|
||||
out = torch.add(self.output(out), residual)
|
||||
out = self.up(out)
|
||||
|
||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||
return out
|
||||
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: esa.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
|
||||
def moment(x, dim=(2, 3), k=2):
|
||||
assert len(x.size()) == 4
|
||||
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
||||
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
||||
return mk
|
||||
|
||||
|
||||
class ESA(nn.Module):
|
||||
"""
|
||||
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
||||
`Residual Feature Aggregation Network for Image Super-Resolution`
|
||||
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
||||
are deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
||||
super(ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
||||
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
c1 = self.conv2(c1_)
|
||||
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
||||
c3 = self.conv3(v_max)
|
||||
c3 = F.interpolate(
|
||||
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
||||
)
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(c3 + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA_LN(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA_LN, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.norm = LayerNorm2d(n_feats)
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.norm(x)
|
||||
c1_ = self.conv1(c1_)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaGuidedFilter, self).__init__()
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=n_feats,
|
||||
out_channels=1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.r = 5
|
||||
|
||||
def box_filter(self, x, r):
|
||||
channel = x.shape[1]
|
||||
kernel_size = 2 * r + 1
|
||||
weight = 1.0 / (kernel_size**2)
|
||||
box_kernel = weight * torch.ones(
|
||||
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
||||
)
|
||||
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
||||
return output
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
N = self.box_filter(
|
||||
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
||||
)
|
||||
|
||||
# epsilon = self.fc(self.gap(x))
|
||||
# epsilon = torch.pow(epsilon, 2)
|
||||
epsilon = 1e-2
|
||||
|
||||
mean_x = self.box_filter(x, self.r) / N
|
||||
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
||||
|
||||
A = var_x / (var_x + epsilon)
|
||||
b = (1 - A) * mean_x
|
||||
m = A * x + b
|
||||
|
||||
# mean_A = self.box_filter(A, self.r) / N
|
||||
# mean_b = self.box_filter(b, self.r) / N
|
||||
# m = mean_A * x + mean_b
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaConvGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaConvGuidedFilter, self).__init__()
|
||||
f = esa_channels
|
||||
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.vec_conv(x)
|
||||
y = self.hor_conv(y)
|
||||
|
||||
sigma = torch.pow(y, 2)
|
||||
epsilon = self.fc(self.gap(y))
|
||||
|
||||
weight = sigma / (sigma + epsilon)
|
||||
|
||||
m = weight * x + (1 - weight)
|
||||
|
||||
return x * m
|
||||
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: layernorm.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias, eps):
|
||||
ctx.eps = eps
|
||||
N, C, H, W = x.size()
|
||||
mu = x.mean(1, keepdim=True)
|
||||
var = (x - mu).pow(2).mean(1, keepdim=True)
|
||||
y = (x - mu) / (var + eps).sqrt()
|
||||
ctx.save_for_backward(y, var, weight)
|
||||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
eps = ctx.eps
|
||||
|
||||
N, C, H, W = grad_output.size()
|
||||
y, var, weight = ctx.saved_variables
|
||||
g = grad_output * weight.view(1, C, 1, 1)
|
||||
mean_g = g.mean(dim=1, keepdim=True)
|
||||
|
||||
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
||||
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
||||
return (
|
||||
gx,
|
||||
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, channels, eps=1e-6):
|
||||
super(LayerNorm2d, self).__init__()
|
||||
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
||||
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
"""GRN (Global Response Normalization) layer"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: pixelshuffle.py
|
||||
# Created Date: Friday July 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def pixelshuffle_block(
|
||||
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
||||
):
|
||||
"""
|
||||
Upsample features according to `upscale_factor`.
|
||||
"""
|
||||
padding = kernel_size // 2
|
||||
conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels * (upscale_factor**2),
|
||||
kernel_size,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
return nn.Sequential(*[conv, pixel_shuffle])
|
||||
@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
|
||||
self.scale: int = self.get_scale()
|
||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
||||
|
||||
c2x2 = False
|
||||
if self.state["model.0.weight"].shape[-2] == 2:
|
||||
c2x2 = True
|
||||
self.scale = round(math.sqrt(self.scale / 4))
|
||||
self.model_arch = "ESRGAN-2c2"
|
||||
|
||||
self.supports_fp16 = True
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = None
|
||||
@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
|
||||
out_nc=self.num_filters,
|
||||
upscale_factor=3,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
else:
|
||||
upsample_blocks = [
|
||||
upsample_block(
|
||||
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(int(math.log(self.scale, 2)))
|
||||
]
|
||||
@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
B.ShortcutBlock(
|
||||
B.sequential(
|
||||
@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
|
||||
act_type=self.act,
|
||||
mode="CNA",
|
||||
plus=self.plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(self.num_blocks)
|
||||
],
|
||||
@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
|
||||
norm_type=self.norm,
|
||||
act_type=None,
|
||||
mode=self.mode,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
),
|
||||
@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
# hr_conv1
|
||||
B.conv_block(
|
||||
@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -141,6 +141,19 @@ def sequential(*args):
|
||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
||||
|
||||
|
||||
# 2x2x2 Conv Block
|
||||
def conv_block_2c2(
|
||||
in_nc,
|
||||
out_nc,
|
||||
act_type="relu",
|
||||
):
|
||||
return sequential(
|
||||
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
||||
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
||||
act(act_type) if act_type else None,
|
||||
)
|
||||
|
||||
|
||||
def conv_block(
|
||||
in_nc: int,
|
||||
out_nc: int,
|
||||
@ -153,12 +166,17 @@ def conv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type: str | None = "relu",
|
||||
mode: ConvMode = "CNA",
|
||||
c2x2=False,
|
||||
):
|
||||
"""
|
||||
Conv layer with padding, normalization, activation
|
||||
mode: CNA --> Conv -> Norm -> Act
|
||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||
"""
|
||||
|
||||
if c2x2:
|
||||
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
||||
|
||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
||||
@ -285,6 +303,7 @@ class RRDB(nn.Module):
|
||||
_convtype="Conv2D",
|
||||
_spectral_norm=False,
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(
|
||||
@ -298,6 +317,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB2 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -310,6 +330,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB3 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -322,6 +343,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
act_type="leakyrelu",
|
||||
mode: ConvMode = "CNA",
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv2 = conv_block(
|
||||
nf + gc,
|
||||
@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv3 = conv_block(
|
||||
nf + 2 * gc,
|
||||
@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv4 = conv_block(
|
||||
nf + 3 * gc,
|
||||
@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
if mode == "CNA":
|
||||
last_act = None
|
||||
@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=last_act,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -499,6 +527,7 @@ def upconv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type="relu",
|
||||
mode="nearest",
|
||||
c2x2=False,
|
||||
):
|
||||
# Up conv
|
||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
||||
@ -512,5 +541,6 @@ def upconv_block(
|
||||
pad_type=pad_type,
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
state_dict = state_dict["params"]
|
||||
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
# SRVGGNet Real-ESRGAN (v2)
|
||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
||||
model = RealESRGANv2(state_dict)
|
||||
@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
# MAT
|
||||
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
||||
model = MAT(state_dict)
|
||||
# Omni-SR
|
||||
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
||||
model = OmniSR(state_dict)
|
||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
||||
else:
|
||||
try:
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
||||
from .architecture.Swin2SR import Swin2SR
|
||||
from .architecture.SwinIR import SwinIR
|
||||
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
|
||||
PyTorchSRModel = Union[
|
||||
RealESRGANv2,
|
||||
SPSR,
|
||||
@ -22,6 +23,7 @@ PyTorchSRModel = Union[
|
||||
SwinIR,
|
||||
Swin2SR,
|
||||
HAT,
|
||||
OmniSR,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -72,7 +72,7 @@ class MaskToImage:
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
|
||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
|
||||
class ImageToMask:
|
||||
@ -167,7 +167,7 @@ class MaskComposite:
|
||||
"source": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"operation": (["multiply", "add", "subtract"],),
|
||||
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
||||
}
|
||||
}
|
||||
|
||||
@ -193,6 +193,12 @@ class MaskComposite:
|
||||
output[top:bottom, left:right] = destination_portion + source_portion
|
||||
elif operation == "subtract":
|
||||
output[top:bottom, left:right] = destination_portion - source_portion
|
||||
elif operation == "and":
|
||||
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||
elif operation == "or":
|
||||
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||
elif operation == "xor":
|
||||
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
|
||||
|
||||
output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
|
||||
@ -59,6 +59,12 @@ class Blend:
|
||||
def g(self, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def gaussian_kernel(kernel_size: int, sigma: float):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
|
||||
class Blur:
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -88,12 +94,6 @@ class Blur:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def gaussian_kernel(self, kernel_size: int, sigma: float):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
|
||||
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
||||
if blur_radius == 0:
|
||||
return (image,)
|
||||
@ -101,10 +101,11 @@ class Blur:
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = blur_radius * 2 + 1
|
||||
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
|
||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
|
||||
blurred = blurred.permute(0, 2, 3, 1)
|
||||
|
||||
return (blurred,)
|
||||
@ -167,9 +168,15 @@ class Sharpen:
|
||||
"max": 31,
|
||||
"step": 1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"sigma": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.1,
|
||||
"max": 10.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 5.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
@ -181,21 +188,21 @@ class Sharpen:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
|
||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
|
||||
if sharpen_radius == 0:
|
||||
return (image,)
|
||||
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
|
||||
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel_size**2
|
||||
kernel *= alpha
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
|
||||
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||
sharpened = sharpened.permute(0, 2, 3, 1)
|
||||
|
||||
result = torch.clamp(sharpened, 0, 1)
|
||||
|
||||
108
comfy_extras/nodes_rebatch.py
Normal file
108
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,108 @@
|
||||
import torch
|
||||
|
||||
class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
@staticmethod
|
||||
def get_batch(latents, list_ind, offset):
|
||||
'''prepare a batch out of the list of latents'''
|
||||
samples = latents[list_ind]['samples']
|
||||
shape = samples.shape
|
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||
if mask.shape[0] < samples.shape[0]:
|
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||
if 'batch_index' in latents[list_ind]:
|
||||
batch_inds = latents[list_ind]['batch_index']
|
||||
else:
|
||||
batch_inds = [x+offset for x in range(shape[0])]
|
||||
return samples, mask, batch_inds
|
||||
|
||||
@staticmethod
|
||||
def get_slices(indexable, num, batch_size):
|
||||
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||
slices = []
|
||||
for i in range(num):
|
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||
if num * batch_size < len(indexable):
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
return list(zip(*result))
|
||||
|
||||
@staticmethod
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
|
||||
for i in range(len(latents)):
|
||||
# fetch new entry of list
|
||||
#samples, masks, indices = self.get_batch(latents, i)
|
||||
next_batch = self.get_batch(latents, i, processed)
|
||||
processed += len(next_batch[2])
|
||||
# set to current if current is None
|
||||
if current_batch[0] is None:
|
||||
current_batch = next_batch
|
||||
# add previous to list if dimensions do not match
|
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
current_batch = next_batch
|
||||
# cat if everything checks out
|
||||
else:
|
||||
current_batch = self.cat_batch(current_batch, next_batch)
|
||||
|
||||
# add to list if dimensions gone above target batch size
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
current_batch = remainder
|
||||
|
||||
#add remainder
|
||||
if current_batch[0] is not None:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
|
||||
#get rid of empty masks
|
||||
for s in output_list:
|
||||
if s['noise_mask'].mean() == 1.0:
|
||||
del s['noise_mask']
|
||||
|
||||
return (output_list,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RebatchLatents": LatentRebatch,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
||||
@ -17,7 +17,7 @@ class UpscaleModelLoader:
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
out = model_loading.load_state_dict(sd).eval()
|
||||
return (out, )
|
||||
|
||||
|
||||
559
execution.py
559
execution.py
@ -6,6 +6,7 @@ import threading
|
||||
import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
@ -26,27 +27,96 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = input_data
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = prompt
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = extra_data['extra_pnginfo']
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = unique_id
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
results = []
|
||||
if intput_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
return None
|
||||
elif isinstance(x, (int, float, bool, str)):
|
||||
return x
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@ -55,23 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id]:
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
input_data_formatted = {}
|
||||
if input_data_all is not None:
|
||||
input_data_formatted = {}
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
}
|
||||
return (False, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -105,7 +216,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
@ -144,10 +256,53 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.outputs = {}
|
||||
self.outputs_ui = {}
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def execute(self, prompt, extra_data={}):
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
node_id = error["node_id"]
|
||||
class_type = prompt[node_id]["class_type"]
|
||||
|
||||
# First, send back the status to the frontend depending
|
||||
# on the exception type
|
||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
}
|
||||
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
|
||||
else:
|
||||
if self.server.client_id is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
"current_inputs": error["current_inputs"],
|
||||
"current_outputs": error["current_outputs"],
|
||||
}
|
||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@ -155,6 +310,10 @@ class PromptExecutor:
|
||||
else:
|
||||
self.server.client_id = None
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
@ -169,105 +328,250 @@ class PromptExecutor:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in prompt:
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
to_execute += [(0, x)]
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
x = to_execute.pop(0)[-1]
|
||||
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
if class_.OUTPUT_NODE == True:
|
||||
valid = False
|
||||
try:
|
||||
m = validate_inputs(prompt, x)
|
||||
valid = m[0]
|
||||
except:
|
||||
valid = False
|
||||
if valid:
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
finally:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
||||
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
if success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item):
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
for x in required_inputs:
|
||||
if x not in inputs:
|
||||
return (False, "Required input is missing. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
return (False, "Bad Input. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "bad_linked_input",
|
||||
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||
r = validate_inputs(prompt, o_id)
|
||||
if r[0] == False:
|
||||
return r
|
||||
received_type = r[val[1]]
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_type": received_type,
|
||||
"linked_node": val
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
try:
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] is False:
|
||||
# `r` will be set in `validated[o_id]` already
|
||||
valid = False
|
||||
continue
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_inner_validation",
|
||||
"message": "Exception when validating inner node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"linked_node": val
|
||||
}
|
||||
}]
|
||||
validated[o_id] = (False, reasons, o_id)
|
||||
continue
|
||||
else:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
try:
|
||||
if type_input == "INT":
|
||||
val = int(val)
|
||||
inputs[x] = val
|
||||
if type_input == "FLOAT":
|
||||
val = float(val)
|
||||
inputs[x] = val
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
"message": f"Failed to convert an input value to a {type_input} value",
|
||||
"details": f"{x}, {val}, {ex}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
"exception_message": str(ex)
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
return (False, "Value smaller than min. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
if ret != True:
|
||||
return (False, "{}, {}".format(class_type, ret))
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
return (True, "")
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
# Don't send back gigantic lists like if they're lots of
|
||||
# scanned model filepaths
|
||||
if len(type_input) > 20:
|
||||
list_info = f"(list of length {len(type_input)})"
|
||||
input_config = None
|
||||
else:
|
||||
list_info = str(type_input)
|
||||
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
else:
|
||||
ret = (True, [], unique_id)
|
||||
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
def full_type_name(klass):
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
outputs = set()
|
||||
@ -277,34 +581,86 @@ def validate_prompt(prompt):
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
return (False, "Prompt has no outputs")
|
||||
error = {
|
||||
"type": "prompt_no_outputs",
|
||||
"message": "Prompt has no outputs",
|
||||
"details": "",
|
||||
"extra_info": {}
|
||||
}
|
||||
return (False, error, [], [])
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
node_errors = {}
|
||||
validated = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reason = ""
|
||||
reasons = []
|
||||
try:
|
||||
m = validate_inputs(prompt, o)
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
reasons = m[1]
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
exception_type = full_type_name(typ)
|
||||
reasons = [{
|
||||
"type": "exception_during_validation",
|
||||
"message": "Exception when validating node",
|
||||
"details": str(ex),
|
||||
"extra_info": {
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb)
|
||||
}
|
||||
}]
|
||||
validated[o] = (False, reasons, o)
|
||||
|
||||
if valid == True:
|
||||
good_outputs.add(x)
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||
print("output will be ignored")
|
||||
errors += [(o, reason)]
|
||||
print(f"Failed to validate prompt for output {o}:")
|
||||
if len(reasons) > 0:
|
||||
print("* (prompt):")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
errors += [(o, reasons)]
|
||||
for node_id, result in validated.items():
|
||||
valid = result[0]
|
||||
reasons = result[1]
|
||||
# If a node upstream has errors, the nodes downstream will also
|
||||
# be reported as invalid, but there will be no errors attached.
|
||||
# So don't return those nodes as having errors in the response.
|
||||
if valid is not True and len(reasons) > 0:
|
||||
if node_id not in node_errors:
|
||||
class_type = prompt[node_id]['class_type']
|
||||
node_errors[node_id] = {
|
||||
"errors": reasons,
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type
|
||||
}
|
||||
print(f"* {class_type} {node_id}:")
|
||||
for reason in reasons:
|
||||
print(f" - {reason['message']}: {reason['details']}")
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
print("Output will be ignored")
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||
errors_list = []
|
||||
for o, errors in errors:
|
||||
for error in errors:
|
||||
errors_list.append(f"{error['message']}: {error['details']}")
|
||||
errors_list = "\n".join(errors_list)
|
||||
|
||||
return (True, "")
|
||||
error = {
|
||||
"type": "prompt_outputs_failed_validation",
|
||||
"message": "Prompt outputs failed validation",
|
||||
"details": errors_list,
|
||||
"extra_info": {}
|
||||
}
|
||||
|
||||
return (False, error, list(good_outputs), node_errors)
|
||||
|
||||
return (True, None, list(good_outputs), node_errors)
|
||||
|
||||
|
||||
class PromptQueue:
|
||||
@ -340,8 +696,7 @@ class PromptQueue:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
if "ui" in outputs[o]:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
||||
@ -1,14 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions.add('.safetensors')
|
||||
supported_pt_extensions.add('.safetensors')
|
||||
except:
|
||||
print("Could not import safetensors, safetensors support disabled.")
|
||||
|
||||
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
|
||||
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
|
||||
|
||||
folder_names_and_paths = {}
|
||||
|
||||
@ -38,6 +32,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
|
||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
|
||||
filename_list_cache = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
os.makedirs(input_directory)
|
||||
|
||||
@ -118,12 +114,18 @@ def get_folder_paths(folder_name):
|
||||
return folder_names_and_paths[folder_name][0][:]
|
||||
|
||||
def recursive_search(directory):
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
result = []
|
||||
dirs = {directory: os.path.getmtime(directory)}
|
||||
for root, subdir, file in os.walk(directory, followlinks=True):
|
||||
for filepath in file:
|
||||
#we os.path,join directory with a blank string to generate a path separator at the end.
|
||||
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
|
||||
return result
|
||||
for d in subdir:
|
||||
path = os.path.join(root, d)
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
return result, dirs
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
|
||||
@ -132,19 +134,90 @@ def filter_files_extensions(files, extensions):
|
||||
|
||||
def get_full_path(folder_name, filename):
|
||||
global folder_names_and_paths
|
||||
if folder_name not in folder_names_and_paths:
|
||||
return None
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
filename = os.path.relpath(os.path.join("/", filename), "/")
|
||||
for x in folders[0]:
|
||||
full_path = os.path.join(x, filename)
|
||||
if os.path.isfile(full_path):
|
||||
return full_path
|
||||
|
||||
return None
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
def get_filename_list_(folder_name):
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
output_folders = {}
|
||||
for x in folders[0]:
|
||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||
return sorted(list(output_list))
|
||||
files, folders_all = recursive_search(x)
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return (sorted(list(output_list)), output_folders, time.perf_counter())
|
||||
|
||||
def cached_filename_list_(folder_name):
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
if time.perf_counter() < (out[2] + 0.5):
|
||||
return out
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
if os.path.isdir(x):
|
||||
if x not in out[1]:
|
||||
return None
|
||||
|
||||
return out
|
||||
|
||||
def get_filename_list(folder_name):
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
global filename_list_cache
|
||||
filename_list_cache[folder_name] = out
|
||||
return list(out[0])
|
||||
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input, image_width, image_height):
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||
|
||||
4
main.py
4
main.py
@ -33,8 +33,8 @@ def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
e.execute(item[-2], item[-1])
|
||||
q.task_done(item_id, e.outputs)
|
||||
e.execute(item[2], item[1], item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
||||
270
nodes.py
270
nodes.py
@ -6,16 +6,18 @@ import json
|
||||
import hashlib
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
|
||||
import comfy.diffusers_convert
|
||||
import comfy.diffusers_load
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
import comfy.sd
|
||||
@ -28,6 +30,7 @@ import importlib
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
@ -145,9 +148,6 @@ class ConditioningSetMask:
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -160,9 +160,6 @@ class VAEDecode:
|
||||
return (vae.decode(samples["samples"]), )
|
||||
|
||||
class VAEDecodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -175,9 +172,6 @@ class VAEDecodeTiled:
|
||||
return (vae.decode_tiled(samples["samples"]), )
|
||||
|
||||
class VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -202,9 +196,6 @@ class VAEEncode:
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -219,9 +210,6 @@ class VAEEncodeTiled:
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
||||
@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
|
||||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ),
|
||||
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
|
||||
# support save metadata for latent sharing
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {"prompt": prompt_info}
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
file = f"{filename}_{counter:05}_.latent"
|
||||
file = os.path.join(full_output_folder, file)
|
||||
|
||||
output = {}
|
||||
output["latent_tensor"] = samples["samples"]
|
||||
|
||||
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
class LoadLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, latent):
|
||||
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||
samples = {"samples": latent["latent_tensor"].float()}
|
||||
return (samples, )
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, latent):
|
||||
image_path = folder_paths.get_annotated_filepath(latent)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, latent):
|
||||
if not folder_paths.exists_annotated_filepath(latent):
|
||||
return "Invalid latent file: {}".format(latent)
|
||||
return True
|
||||
|
||||
|
||||
class CheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -296,7 +359,10 @@ class DiffusersLoader:
|
||||
paths = []
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths += next(os.walk(search_path))[1]
|
||||
for root, subdir, files in os.walk(search_path, followlinks=True):
|
||||
if "model_index.json" in files:
|
||||
paths.append(os.path.relpath(root, start=search_path))
|
||||
|
||||
return {"required": {"model_path": (paths,), }}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -306,12 +372,12 @@ class DiffusersLoader:
|
||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths = next(os.walk(search_path))[1]
|
||||
if model_path in paths:
|
||||
model_path = os.path.join(search_path, model_path)
|
||||
path = os.path.join(search_path, model_path)
|
||||
if os.path.exists(path):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
|
||||
class unCLIPCheckpointLoader:
|
||||
@ -360,6 +426,9 @@ class LoraLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return (model, clip)
|
||||
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
|
||||
return (model_lora, clip_lora)
|
||||
@ -517,9 +586,11 @@ class ControlNetApply:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
print(control_hint.shape)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
||||
@ -624,6 +695,9 @@ class unCLIPConditioning:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
c = []
|
||||
for t in conditioning:
|
||||
o = t[1].copy()
|
||||
@ -706,22 +780,61 @@ class LatentFromBatch:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "rotate"
|
||||
FUNCTION = "frombatch"
|
||||
|
||||
CATEGORY = "latent"
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def rotate(self, samples, batch_index):
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
s["samples"] = s_in[batch_index:batch_index + 1].clone()
|
||||
s["batch_index"] = batch_index
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] == 1:
|
||||
s["noise_mask"] = masks.clone()
|
||||
else:
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||
if "batch_index" not in s:
|
||||
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||
if "batch_index" in s:
|
||||
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
@classmethod
|
||||
@ -740,6 +853,25 @@ class LatentUpscale:
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
return (s,)
|
||||
|
||||
class LatentUpscaleBy:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
|
||||
"scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def upscale(self, samples, upscale_method, scale_by):
|
||||
s = samples.copy()
|
||||
width = round(samples["samples"].shape[3] * scale_by)
|
||||
height = round(samples["samples"].shape[2] * scale_by)
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
||||
return (s,)
|
||||
|
||||
class LatentRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -872,7 +1004,7 @@ class SetLatentNoiseMask:
|
||||
|
||||
def set_mask(self, samples, mask):
|
||||
s = samples.copy()
|
||||
s["noise_mask"] = mask
|
||||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
return (s,)
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
@ -882,8 +1014,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
@ -978,39 +1110,7 @@ class SaveImage:
|
||||
CATEGORY = "image"
|
||||
|
||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input):
|
||||
input = input.replace("%width%", str(images[0].shape[1]))
|
||||
input = input.replace("%height%", str(images[0].shape[0]))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(self.output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
@ -1049,8 +1149,9 @@ class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), )},
|
||||
{"image": (sorted(files), )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -1060,6 +1161,7 @@ class LoadImage:
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
@ -1090,9 +1192,10 @@ class LoadImageMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), ),
|
||||
"channel": (s._color_channels, ),}
|
||||
{"image": (sorted(files), ),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
@ -1102,6 +1205,7 @@ class LoadImageMask:
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
@ -1244,7 +1348,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAELoader": VAELoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentUpscaleBy": LatentUpscaleBy,
|
||||
"LatentFromBatch": LatentFromBatch,
|
||||
"RepeatLatentBatch": RepeatLatentBatch,
|
||||
"SaveImage": SaveImage,
|
||||
"PreviewImage": PreviewImage,
|
||||
"LoadImage": LoadImage,
|
||||
@ -1282,6 +1388,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
"SaveLatent": SaveLatent
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1319,7 +1428,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LatentCrop": "Crop Latent",
|
||||
"EmptyLatentImage": "Empty Latent Image",
|
||||
"LatentUpscale": "Upscale Latent",
|
||||
"LatentUpscaleBy": "Upscale Latent By",
|
||||
"LatentComposite": "Latent Composite",
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
# Image
|
||||
"SaveImage": "Save Image",
|
||||
"PreviewImage": "Preview Image",
|
||||
@ -1351,14 +1463,18 @@ def load_custom_node(module_path):
|
||||
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
else:
|
||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||
return False
|
||||
|
||||
def load_custom_nodes():
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
if "__pycache__" in possible_modules:
|
||||
@ -1367,11 +1483,25 @@ def load_custom_nodes():
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
load_custom_node(module_path)
|
||||
if module_path.endswith(".disabled"): continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path)
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
print("\nImport times for custom nodes:")
|
||||
for n in sorted(node_import_times):
|
||||
if n[2]:
|
||||
import_message = ""
|
||||
else:
|
||||
import_message = " (IMPORT FAILED)"
|
||||
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||
print()
|
||||
|
||||
def init_custom_nodes():
|
||||
load_custom_nodes()
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||
load_custom_nodes()
|
||||
|
||||
@ -175,6 +175,8 @@
|
||||
"import threading\n",
|
||||
"import time\n",
|
||||
"import socket\n",
|
||||
"import urllib.request\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
@ -183,7 +185,9 @@
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||
"\n",
|
||||
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end='')\n",
|
||||
|
||||
250
server.py
250
server.py
@ -7,6 +7,9 @@ import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
@ -19,7 +22,8 @@ except ImportError:
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
@web.middleware
|
||||
async def cache_control(request: web.Request, handler):
|
||||
@ -78,7 +82,7 @@ class PromptServer():
|
||||
# Reusing existing session, remove old
|
||||
self.sockets.pop(sid, None)
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
self.sockets[sid] = ws
|
||||
|
||||
@ -110,49 +114,96 @@ class PromptServer():
|
||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
def get_dir_by_type(dir_type):
|
||||
if dir_type is None:
|
||||
dir_type = "input"
|
||||
|
||||
if dir_type == "input":
|
||||
type_dir = folder_paths.get_input_directory()
|
||||
elif dir_type == "temp":
|
||||
type_dir = folder_paths.get_temp_directory()
|
||||
elif dir_type == "output":
|
||||
type_dir = folder_paths.get_output_directory()
|
||||
|
||||
return type_dir, dir_type
|
||||
|
||||
def image_upload(post, image_save_function=None):
|
||||
image = post.get("image")
|
||||
overwrite = post.get("overwrite")
|
||||
|
||||
if post.get("type") is None:
|
||||
upload_dir = folder_paths.get_input_directory()
|
||||
elif post.get("type") == "input":
|
||||
upload_dir = folder_paths.get_input_directory()
|
||||
elif post.get("type") == "temp":
|
||||
upload_dir = folder_paths.get_temp_directory()
|
||||
elif post.get("type") == "output":
|
||||
upload_dir = folder_paths.get_output_directory()
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
image_upload_type = post.get("type")
|
||||
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
|
||||
|
||||
if image and image.file:
|
||||
filename = image.filename
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
subfolder = post.get("subfolder", "")
|
||||
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
|
||||
|
||||
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
|
||||
return web.Response(status=400)
|
||||
|
||||
if not os.path.exists(full_output_folder):
|
||||
os.makedirs(full_output_folder)
|
||||
|
||||
split = os.path.splitext(filename)
|
||||
i = 1
|
||||
while os.path.exists(os.path.join(upload_dir, filename)):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
i += 1
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
|
||||
filepath = os.path.join(upload_dir, filename)
|
||||
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
|
||||
pass
|
||||
else:
|
||||
i = 1
|
||||
while os.path.exists(filepath):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
i += 1
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename})
|
||||
if image_save_function is not None:
|
||||
image_save_function(image, post, filepath)
|
||||
else:
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
return image_upload(post)
|
||||
|
||||
@routes.post("/upload/mask")
|
||||
async def upload_mask(request):
|
||||
post = await request.post()
|
||||
|
||||
def image_save_function(image, post, filepath):
|
||||
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
|
||||
mask_pil = Image.open(image.file).convert('RGBA')
|
||||
|
||||
# alpha copy
|
||||
new_alpha = mask_pil.getchannel('A')
|
||||
original_pil.putalpha(new_alpha)
|
||||
original_pil.save(filepath, compress_level=4)
|
||||
|
||||
return image_upload(post, image_save_function)
|
||||
|
||||
@routes.get("/view")
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
@ -162,35 +213,132 @@ class PromptServer():
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
else:
|
||||
channel = request.rel_url.query["channel"]
|
||||
|
||||
if channel == 'rgb':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
r, g, b, a = img.split()
|
||||
new_img = Image.merge('RGB', (r, g, b))
|
||||
else:
|
||||
new_img = img.convert("RGB")
|
||||
|
||||
buffer = BytesIO()
|
||||
new_img.save(buffer, format='PNG')
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
_, _, _, a = img.split()
|
||||
else:
|
||||
a = Image.new('L', img.size, 255)
|
||||
|
||||
# alpha img
|
||||
alpha_img = Image.new('RGBA', img.size)
|
||||
alpha_img.putalpha(a)
|
||||
alpha_buffer = BytesIO()
|
||||
alpha_img.save(alpha_buffer, format='PNG')
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/view_metadata/{folder_name}")
|
||||
async def view_metadata(request):
|
||||
folder_name = request.match_info.get("folder_name", None)
|
||||
if folder_name is None:
|
||||
return web.Response(status=404)
|
||||
if not "filename" in request.rel_url.query:
|
||||
return web.Response(status=404)
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
if not filename.endswith(".safetensors"):
|
||||
return web.Response(status=404)
|
||||
|
||||
safetensors_path = folder_paths.get_full_path(folder_name, filename)
|
||||
if safetensors_path is None:
|
||||
return web.Response(status=404)
|
||||
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
|
||||
if out is None:
|
||||
return web.Response(status=404)
|
||||
dt = json.loads(out)
|
||||
if not "__metadata__" in dt:
|
||||
return web.Response(status=404)
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@routes.get("/system_stats")
|
||||
async def get_queue(request):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
|
||||
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
|
||||
system_stats = {
|
||||
"devices": [
|
||||
{
|
||||
"name": device_name,
|
||||
"type": device.type,
|
||||
"index": device.index,
|
||||
"vram_total": vram_total,
|
||||
"vram_free": vram_free,
|
||||
"torch_vram_total": torch_vram_total,
|
||||
"torch_vram_free": torch_vram_free,
|
||||
}
|
||||
]
|
||||
}
|
||||
return web.json_response(system_stats)
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = x
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
out[x] = info
|
||||
out[x] = node_info(x)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/object_info/{node_class}")
|
||||
async def get_object_info_node(request):
|
||||
node_class = request.match_info.get("node_class", None)
|
||||
out = {}
|
||||
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/history")
|
||||
@ -232,14 +380,16 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
return web.json_response({"prompt_id": prompt_id, "number": number})
|
||||
else:
|
||||
resp_code = 400
|
||||
out_string = valid[1]
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
else:
|
||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||
|
||||
return web.Response(body=out_string, status=resp_code)
|
||||
|
||||
@routes.post("/queue")
|
||||
async def post_queue(request):
|
||||
json_data = await request.json()
|
||||
@ -249,9 +399,9 @@ class PromptServer():
|
||||
if "delete" in json_data:
|
||||
to_delete = json_data['delete']
|
||||
for id_to_delete in to_delete:
|
||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
||||
delete_func = lambda a: a[1] == id_to_delete
|
||||
self.prompt_queue.delete_queue_item(delete_func)
|
||||
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/interrupt")
|
||||
@ -275,7 +425,7 @@ class PromptServer():
|
||||
def add_routes(self):
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
web.static('/', self.web_root, follow_symlinks=True),
|
||||
])
|
||||
|
||||
def get_queue_info(self):
|
||||
|
||||
166
web/extensions/core/clipspace.js
Normal file
166
web/extensions/core/clipspace.js
Normal file
@ -0,0 +1,166 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
|
||||
export class ClipspaceDialog extends ComfyDialog {
|
||||
static items = [];
|
||||
static instance = null;
|
||||
|
||||
static registerButton(name, contextPredicate, callback) {
|
||||
const item =
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: name,
|
||||
contextPredicate: contextPredicate,
|
||||
onclick: callback
|
||||
})
|
||||
|
||||
ClipspaceDialog.items.push(item);
|
||||
}
|
||||
|
||||
static invalidatePreview() {
|
||||
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
if(img_preview) {
|
||||
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
img_preview.style.maxHeight = "100%";
|
||||
img_preview.style.maxWidth = "100%";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static invalidate() {
|
||||
if(ClipspaceDialog.instance) {
|
||||
const self = ClipspaceDialog.instance;
|
||||
// allow reconstruct controls when copying from non-image to image content.
|
||||
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
|
||||
|
||||
if(self.element) {
|
||||
// update
|
||||
self.element.removeChild(self.element.firstChild);
|
||||
self.element.appendChild(children);
|
||||
}
|
||||
else {
|
||||
// new
|
||||
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
|
||||
}
|
||||
|
||||
if(self.element.children[0].children.length <= 1) {
|
||||
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
|
||||
}
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
}
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
|
||||
createButtons(self) {
|
||||
const buttons = [];
|
||||
|
||||
for(let idx in ClipspaceDialog.items) {
|
||||
const item = ClipspaceDialog.items[idx];
|
||||
if(!item.contextPredicate || item.contextPredicate())
|
||||
buttons.push(ClipspaceDialog.items[idx]);
|
||||
}
|
||||
|
||||
buttons.push(
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
onclick: () => { this.close(); }
|
||||
})
|
||||
);
|
||||
|
||||
return buttons;
|
||||
}
|
||||
|
||||
createImgSettings() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
const combo_items = [];
|
||||
const imgs = ComfyApp.clipspace.imgs;
|
||||
|
||||
for(let i=0; i < imgs.length; i++) {
|
||||
combo_items.push($el("option", {value:i}, [`${i}`]));
|
||||
}
|
||||
|
||||
const combo1 = $el("select",
|
||||
{id:"clipspace_img_selector", onchange:(event) => {
|
||||
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
} }, combo_items);
|
||||
|
||||
const row1 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
|
||||
$el("td", {}, [combo1])
|
||||
]);
|
||||
|
||||
|
||||
const combo2 = $el("select",
|
||||
{id:"clipspace_img_paste_mode", onchange:(event) => {
|
||||
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
|
||||
} },
|
||||
[
|
||||
$el("option", {value:'selected'}, 'selected'),
|
||||
$el("option", {value:'all'}, 'all')
|
||||
]);
|
||||
combo2.value = ComfyApp.clipspace['img_paste_mode'];
|
||||
|
||||
const row2 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
|
||||
$el("td", {}, [combo2])
|
||||
]);
|
||||
|
||||
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
|
||||
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
|
||||
|
||||
const row3 =
|
||||
$el("tr", {}, [td]);
|
||||
|
||||
return $el("table", {}, [row1, row2, row3]);
|
||||
}
|
||||
else {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
createImgPreview() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
|
||||
}
|
||||
else
|
||||
return [];
|
||||
}
|
||||
|
||||
show() {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
ClipspaceDialog.invalidate();
|
||||
|
||||
this.element.style.display = "block";
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.Clipspace",
|
||||
init(app) {
|
||||
app.openClipspace =
|
||||
function () {
|
||||
if(!ClipspaceDialog.instance) {
|
||||
ClipspaceDialog.instance = new ClipspaceDialog(app);
|
||||
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace) {
|
||||
ClipspaceDialog.instance.show();
|
||||
}
|
||||
else
|
||||
app.ui.dialog.show("Clipspace is Empty!");
|
||||
};
|
||||
}
|
||||
});
|
||||
@ -174,7 +174,7 @@ const els = {}
|
||||
// const ctxMenu = LiteGraph.ContextMenu;
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
init() {
|
||||
addCustomNodeDefs(node_defs) {
|
||||
const sortObjectKeys = (unordered) => {
|
||||
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||
obj[key] = unordered[key];
|
||||
@ -182,10 +182,10 @@ app.registerExtension({
|
||||
}, {});
|
||||
};
|
||||
|
||||
const getSlotTypes = async () => {
|
||||
function getSlotTypes() {
|
||||
var types = [];
|
||||
|
||||
const defs = await api.getNodeDefs();
|
||||
const defs = node_defs;
|
||||
for (const nodeId in defs) {
|
||||
const nodeData = defs[nodeId];
|
||||
|
||||
@ -212,8 +212,8 @@ app.registerExtension({
|
||||
return types;
|
||||
};
|
||||
|
||||
const completeColorPalette = async (colorPalette) => {
|
||||
var types = await getSlotTypes();
|
||||
function completeColorPalette(colorPalette) {
|
||||
var types = getSlotTypes();
|
||||
|
||||
for (const type of types) {
|
||||
if (!colorPalette.colors.node_slot[type]) {
|
||||
|
||||
648
web/extensions/core/maskeditor.js
Normal file
648
web/extensions/core/maskeditor.js
Normal file
@ -0,0 +1,648 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
|
||||
|
||||
// Helper function to convert a data URL to a Blob object
|
||||
function dataURLToBlob(dataURL) {
|
||||
const parts = dataURL.split(';base64,');
|
||||
const contentType = parts[0].split(':')[1];
|
||||
const byteString = atob(parts[1]);
|
||||
const arrayBuffer = new ArrayBuffer(byteString.length);
|
||||
const uint8Array = new Uint8Array(arrayBuffer);
|
||||
for (let i = 0; i < byteString.length; i++) {
|
||||
uint8Array[i] = byteString.charCodeAt(i);
|
||||
}
|
||||
return new Blob([arrayBuffer], { type: contentType });
|
||||
}
|
||||
|
||||
function loadedImageToBlob(image) {
|
||||
const canvas = document.createElement('canvas');
|
||||
|
||||
canvas.width = image.width;
|
||||
canvas.height = image.height;
|
||||
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
ctx.drawImage(image, 0, 0);
|
||||
|
||||
const dataURL = canvas.toDataURL('image/png', 1);
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
return blob;
|
||||
}
|
||||
|
||||
async function uploadMask(filepath, formData) {
|
||||
await fetch('/upload/mask', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
}).then(response => {}).catch(error => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
|
||||
function prepareRGB(image, backupCanvas, backupCtx) {
|
||||
// paste mask data into alpha channel
|
||||
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
}
|
||||
|
||||
class MaskEditorDialog extends ComfyDialog {
|
||||
static instance = null;
|
||||
|
||||
static getInstance() {
|
||||
if(!MaskEditorDialog.instance) {
|
||||
MaskEditorDialog.instance = new MaskEditorDialog(app);
|
||||
}
|
||||
|
||||
return MaskEditorDialog.instance;
|
||||
}
|
||||
|
||||
is_layout_created = false;
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.element = $el("div.comfy-modal", { parent: document.body },
|
||||
[ $el("div.comfy-modal-content",
|
||||
[...this.createButtons()]),
|
||||
]);
|
||||
}
|
||||
|
||||
createButtons() {
|
||||
return [];
|
||||
}
|
||||
|
||||
createButton(name, callback) {
|
||||
var button = document.createElement("button");
|
||||
button.innerText = name;
|
||||
button.addEventListener("click", callback);
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "left";
|
||||
button.style.marginRight = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createRightButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "right";
|
||||
button.style.marginLeft = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftSlider(self, name, callback) {
|
||||
const divElement = document.createElement('div');
|
||||
divElement.id = "maskeditor-slider";
|
||||
divElement.style.cssFloat = "left";
|
||||
divElement.style.fontFamily = "sans-serif";
|
||||
divElement.style.marginRight = "4px";
|
||||
divElement.style.color = "var(--input-text)";
|
||||
divElement.style.backgroundColor = "var(--comfy-input-bg)";
|
||||
divElement.style.borderRadius = "8px";
|
||||
divElement.style.borderColor = "var(--border-color)";
|
||||
divElement.style.borderStyle = "solid";
|
||||
divElement.style.fontSize = "15px";
|
||||
divElement.style.height = "21px";
|
||||
divElement.style.padding = "1px 6px";
|
||||
divElement.style.display = "flex";
|
||||
divElement.style.position = "relative";
|
||||
divElement.style.top = "2px";
|
||||
self.brush_slider_input = document.createElement('input');
|
||||
self.brush_slider_input.setAttribute('type', 'range');
|
||||
self.brush_slider_input.setAttribute('min', '1');
|
||||
self.brush_slider_input.setAttribute('max', '100');
|
||||
self.brush_slider_input.setAttribute('value', '10');
|
||||
const labelElement = document.createElement("label");
|
||||
labelElement.textContent = name;
|
||||
|
||||
divElement.appendChild(labelElement);
|
||||
divElement.appendChild(self.brush_slider_input);
|
||||
|
||||
self.brush_slider_input.addEventListener("change", callback);
|
||||
|
||||
return divElement;
|
||||
}
|
||||
|
||||
setlayout(imgCanvas, maskCanvas) {
|
||||
const self = this;
|
||||
|
||||
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
|
||||
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
|
||||
var placeholder = document.createElement("div");
|
||||
placeholder.style.position = "relative";
|
||||
placeholder.style.height = "50px";
|
||||
|
||||
var bottom_panel = document.createElement("div");
|
||||
bottom_panel.style.position = "absolute";
|
||||
bottom_panel.style.bottom = "0px";
|
||||
bottom_panel.style.left = "20px";
|
||||
bottom_panel.style.right = "20px";
|
||||
bottom_panel.style.height = "50px";
|
||||
|
||||
var brush = document.createElement("div");
|
||||
brush.id = "brush";
|
||||
brush.style.backgroundColor = "transparent";
|
||||
brush.style.outline = "1px dashed black";
|
||||
brush.style.boxShadow = "0 0 0 1px white";
|
||||
brush.style.borderRadius = "50%";
|
||||
brush.style.MozBorderRadius = "50%";
|
||||
brush.style.WebkitBorderRadius = "50%";
|
||||
brush.style.position = "absolute";
|
||||
brush.style.zIndex = 8889;
|
||||
brush.style.pointerEvents = "none";
|
||||
this.brush = brush;
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
document.body.appendChild(brush);
|
||||
|
||||
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
self.brush_size = event.target.value;
|
||||
self.updateBrushPreview(self, null, null);
|
||||
});
|
||||
var clearButton = this.createLeftButton("Clear",
|
||||
() => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
|
||||
});
|
||||
var cancelButton = this.createRightButton("Cancel", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.close();
|
||||
});
|
||||
|
||||
this.saveButton = this.createRightButton("Save", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.save();
|
||||
});
|
||||
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
|
||||
bottom_panel.appendChild(clearButton);
|
||||
bottom_panel.appendChild(this.saveButton);
|
||||
bottom_panel.appendChild(cancelButton);
|
||||
bottom_panel.appendChild(brush_size_slider);
|
||||
|
||||
imgCanvas.style.position = "relative";
|
||||
imgCanvas.style.top = "200";
|
||||
imgCanvas.style.left = "0";
|
||||
|
||||
maskCanvas.style.position = "absolute";
|
||||
}
|
||||
|
||||
show() {
|
||||
if(!this.is_layout_created) {
|
||||
// layout
|
||||
const imgCanvas = document.createElement('canvas');
|
||||
const maskCanvas = document.createElement('canvas');
|
||||
const backupCanvas = document.createElement('canvas');
|
||||
|
||||
imgCanvas.id = "imageCanvas";
|
||||
maskCanvas.id = "maskCanvas";
|
||||
backupCanvas.id = "backupCanvas";
|
||||
|
||||
this.setlayout(imgCanvas, maskCanvas);
|
||||
|
||||
// prepare content
|
||||
this.imgCanvas = imgCanvas;
|
||||
this.maskCanvas = maskCanvas;
|
||||
this.backupCanvas = backupCanvas;
|
||||
this.maskCtx = maskCanvas.getContext('2d');
|
||||
this.backupCtx = backupCanvas.getContext('2d');
|
||||
|
||||
this.setEventHandler(maskCanvas);
|
||||
|
||||
this.is_layout_created = true;
|
||||
|
||||
// replacement of onClose hook since close is not real close
|
||||
const self = this;
|
||||
const observer = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutation) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||
ComfyApp.onClipspaceEditorClosed();
|
||||
}
|
||||
|
||||
self.last_display_style = self.element.style.display;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const config = { attributes: true };
|
||||
observer.observe(this.element, config);
|
||||
}
|
||||
|
||||
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
this.saveButton.innerText = "Save to node";
|
||||
}
|
||||
else {
|
||||
this.saveButton.innerText = "Save";
|
||||
}
|
||||
this.saveButton.disabled = false;
|
||||
|
||||
this.element.style.display = "block";
|
||||
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||
}
|
||||
|
||||
isOpened() {
|
||||
return this.element.style.display == "block";
|
||||
}
|
||||
|
||||
setImages(imgCanvas, backupCanvas) {
|
||||
const imgCtx = imgCanvas.getContext('2d');
|
||||
const backupCtx = backupCanvas.getContext('2d');
|
||||
const maskCtx = this.maskCtx;
|
||||
const maskCanvas = this.maskCanvas;
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||
|
||||
// image load
|
||||
const orig_image = new Image();
|
||||
window.addEventListener("resize", () => {
|
||||
// repositioning
|
||||
imgCanvas.width = window.innerWidth - 250;
|
||||
imgCanvas.height = window.innerHeight - 200;
|
||||
|
||||
// redraw image
|
||||
let drawWidth = orig_image.width;
|
||||
let drawHeight = orig_image.height;
|
||||
if (orig_image.width > imgCanvas.width) {
|
||||
drawWidth = imgCanvas.width;
|
||||
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
||||
}
|
||||
|
||||
if (drawHeight > imgCanvas.height) {
|
||||
drawHeight = imgCanvas.height;
|
||||
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
||||
}
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||
|
||||
// update mask
|
||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
maskCanvas.width = drawWidth;
|
||||
maskCanvas.height = drawHeight;
|
||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
});
|
||||
|
||||
const filepath = ComfyApp.clipspace.images;
|
||||
|
||||
const touched_image = new Image();
|
||||
|
||||
touched_image.onload = function() {
|
||||
backupCanvas.width = touched_image.width;
|
||||
backupCanvas.height = touched_image.height;
|
||||
|
||||
prepareRGB(touched_image, backupCanvas, backupCtx);
|
||||
};
|
||||
|
||||
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||
alpha_url.searchParams.delete('channel');
|
||||
alpha_url.searchParams.set('channel', 'a');
|
||||
touched_image.src = alpha_url;
|
||||
|
||||
// original image load
|
||||
orig_image.onload = function() {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
};
|
||||
|
||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||
rgb_url.searchParams.delete('channel');
|
||||
rgb_url.searchParams.set('channel', 'rgb');
|
||||
orig_image.src = rgb_url;
|
||||
this.image = orig_image;
|
||||
}
|
||||
|
||||
setEventHandler(maskCanvas) {
|
||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||
event.preventDefault();
|
||||
});
|
||||
|
||||
const self = this;
|
||||
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||
}
|
||||
|
||||
brush_size = 10;
|
||||
drawing_mode = false;
|
||||
lastx = -1;
|
||||
lasty = -1;
|
||||
lasttime = 0;
|
||||
|
||||
static handleKeyDown(event) {
|
||||
const self = MaskEditorDialog.instance;
|
||||
if (event.key === ']') {
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
} else if (event.key === '[') {
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
} else if(event.key === 'Enter') {
|
||||
self.save();
|
||||
}
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
static handlePointerUp(event) {
|
||||
event.preventDefault();
|
||||
MaskEditorDialog.instance.drawing_mode = false;
|
||||
}
|
||||
|
||||
updateBrushPreview(self) {
|
||||
const brush = self.brush;
|
||||
|
||||
var centerX = self.cursorX;
|
||||
var centerY = self.cursorY;
|
||||
|
||||
brush.style.width = self.brush_size * 2 + "px";
|
||||
brush.style.height = self.brush_size * 2 + "px";
|
||||
brush.style.left = (centerX - self.brush_size) + "px";
|
||||
brush.style.top = (centerY - self.brush_size) + "px";
|
||||
}
|
||||
|
||||
handleWheelEvent(self, event) {
|
||||
if(event.deltaY < 0)
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
else
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
draw_move(self, event) {
|
||||
event.preventDefault();
|
||||
|
||||
this.cursorX = event.pageX;
|
||||
this.cursorY = event.pageY;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
|
||||
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
|
||||
var diff = performance.now() - self.lasttime;
|
||||
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
|
||||
var x = event.offsetX;
|
||||
var y = event.offsetY
|
||||
|
||||
if(event.offsetX == null) {
|
||||
x = event.targetTouches[0].clientX - maskRect.left;
|
||||
}
|
||||
|
||||
if(event.offsetY == null) {
|
||||
y = event.targetTouches[0].clientY - maskRect.top;
|
||||
}
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
// The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents.
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !this.drawing_mode)
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
handlePointerDown(self, event) {
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
|
||||
if ([0, 2, 5].includes(event.button)) {
|
||||
self.drawing_mode = true;
|
||||
|
||||
event.preventDefault();
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
self.maskCtx.beginPath();
|
||||
if (event.button == 0) {
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
} else {
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
}
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
async save() {
|
||||
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
backupCtx.drawImage(this.maskCanvas,
|
||||
0, 0, this.maskCanvas.width, this.maskCanvas.height,
|
||||
0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// paste mask data into alpha channel
|
||||
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
|
||||
const formData = new FormData();
|
||||
const filename = "clipspace-mask-" + performance.now() + ".png";
|
||||
|
||||
const item =
|
||||
{
|
||||
"filename": filename,
|
||||
"subfolder": "clipspace",
|
||||
"type": "input",
|
||||
};
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[0] = item;
|
||||
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||
|
||||
if(index >= 0)
|
||||
ComfyApp.clipspace.widgets[index].value = item;
|
||||
}
|
||||
|
||||
const dataURL = this.backupCanvas.toDataURL();
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
const original_blob = loadedImageToBlob(this.image);
|
||||
|
||||
formData.append('image', blob, filename);
|
||||
formData.append('original_image', original_blob);
|
||||
formData.append('type', "input");
|
||||
formData.append('subfolder', "clipspace");
|
||||
|
||||
this.saveButton.innerText = "Saving...";
|
||||
this.saveButton.disabled = true;
|
||||
await uploadMask(item, formData);
|
||||
ComfyApp.onClipspaceEditorSave();
|
||||
this.close();
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.MaskEditor",
|
||||
init(app) {
|
||||
ComfyApp.open_maskeditor =
|
||||
function () {
|
||||
const dlg = MaskEditorDialog.getInstance();
|
||||
if(!dlg.isOpened()) {
|
||||
dlg.show();
|
||||
}
|
||||
};
|
||||
|
||||
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
||||
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
|
||||
}
|
||||
});
|
||||
@ -300,7 +300,7 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
if (widget.type === "number") {
|
||||
if (widget.type === "number" || widget.type === "combo") {
|
||||
addValueControlWidget(this, widget, "fixed");
|
||||
}
|
||||
|
||||
|
||||
@ -14,5 +14,5 @@
|
||||
window.graph = app.graph;
|
||||
</script>
|
||||
</head>
|
||||
<body></body>
|
||||
<body class="litegraph"></body>
|
||||
</html>
|
||||
|
||||
@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//when clicked on top of a node
|
||||
//and it is not interactive
|
||||
if (node && this.allow_interaction && !skip_action && !this.read_only) {
|
||||
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
|
||||
if (!this.live_mode && !node.flags.pinned) {
|
||||
this.bringToFront(node);
|
||||
} //if it wasn't selected?
|
||||
|
||||
//not dragging mouse to connect two slots
|
||||
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
//Search for corner for resize
|
||||
if ( !skip_action &&
|
||||
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
||||
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
|
||||
//double clicking
|
||||
if (is_double_click && this.selected_nodes[node.id]) {
|
||||
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
|
||||
//double click node
|
||||
if (node.onDblClick) {
|
||||
node.onDblClick( e, pos, this );
|
||||
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
if (this.dragging_rectangle)
|
||||
{
|
||||
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
||||
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.ds.offset[1] += delta[1] / this.ds.scale;
|
||||
this.dirty_canvas = true;
|
||||
this.dirty_bgcanvas = true;
|
||||
} else if (this.allow_interaction && !this.read_only) {
|
||||
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
|
||||
if (this.connecting_node) {
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
//remove mouseover flag
|
||||
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
||||
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
||||
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
if (show_text) {
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
|
||||
}
|
||||
break;
|
||||
case "toggle":
|
||||
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
if (show_text) {
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(
|
||||
w.name + " " + Number(w.value).toFixed(3),
|
||||
w.label || w.name + " " + Number(w.value).toFixed(3),
|
||||
widget_width * 0.5,
|
||||
y + H * 0.7
|
||||
);
|
||||
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
}
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
if (w.type == "number") {
|
||||
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//ctx.stroke();
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
event,
|
||||
active_widget
|
||||
) {
|
||||
if (!node.widgets || !node.widgets.length) {
|
||||
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
canvas.graph.add(group);
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines the furthest nodes in each direction
|
||||
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.getBoundaryNodes = function(nodes) {
|
||||
let top = null;
|
||||
let right = null;
|
||||
let bottom = null;
|
||||
let left = null;
|
||||
for (const nID in nodes) {
|
||||
const node = nodes[nID];
|
||||
const [x, y] = node.pos;
|
||||
const [width, height] = node.size;
|
||||
|
||||
if (top === null || y < top.pos[1]) {
|
||||
top = node;
|
||||
}
|
||||
if (right === null || x + width > right.pos[0] + right.size[0]) {
|
||||
right = node;
|
||||
}
|
||||
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
|
||||
bottom = node;
|
||||
}
|
||||
if (left === null || x < left.pos[0]) {
|
||||
left = node;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"top": top,
|
||||
"right": right,
|
||||
"bottom": bottom,
|
||||
"left": left
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Determines the furthest nodes in each direction for the currently selected nodes
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
|
||||
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {LGraphNode[]} nodes a list of nodes
|
||||
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
|
||||
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
|
||||
*/
|
||||
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
|
||||
if (!nodes) {
|
||||
return;
|
||||
}
|
||||
|
||||
const canvas = LGraphCanvas.active_canvas;
|
||||
let boundaryNodes = []
|
||||
if (align_to === undefined) {
|
||||
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
|
||||
} else {
|
||||
boundaryNodes = {
|
||||
"top": align_to,
|
||||
"right": align_to,
|
||||
"bottom": align_to,
|
||||
"left": align_to
|
||||
}
|
||||
}
|
||||
|
||||
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
|
||||
switch (direction) {
|
||||
case "right":
|
||||
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
|
||||
break;
|
||||
case "left":
|
||||
node.pos[0] = boundaryNodes["left"].pos[0];
|
||||
break;
|
||||
case "top":
|
||||
node.pos[1] = boundaryNodes["top"].pos[1];
|
||||
break;
|
||||
case "bottom":
|
||||
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
canvas.dirty_canvas = true;
|
||||
canvas.dirty_bgcanvas = true;
|
||||
};
|
||||
|
||||
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
||||
|
||||
var canvas = LGraphCanvas.active_canvas;
|
||||
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
||||
}*/
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onGroupAlign,
|
||||
})
|
||||
}
|
||||
|
||||
if (this._graph_stack && this._graph_stack.length > 0) {
|
||||
options.push(null, {
|
||||
content: "Close subgraph",
|
||||
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
callback: LGraphCanvas.onMenuNodeToSubgraph
|
||||
});
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align Selected To",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onNodeAlign,
|
||||
})
|
||||
}
|
||||
|
||||
options.push(null, {
|
||||
content: "Remove",
|
||||
disabled: !(node.removable !== false && !node.block_delete ),
|
||||
|
||||
@ -88,6 +88,12 @@ class ComfyApi extends EventTarget {
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_start":
|
||||
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
|
||||
break;
|
||||
case "execution_error":
|
||||
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
|
||||
break;
|
||||
default:
|
||||
if (this.#registered.has(msg.type)) {
|
||||
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
|
||||
@ -163,7 +169,7 @@ class ComfyApi extends EventTarget {
|
||||
|
||||
if (res.status !== 200) {
|
||||
throw {
|
||||
response: await res.text(),
|
||||
response: await res.json(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
|
||||
import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
import { getPngMetadata, importA1111 } from "./pnginfo.js";
|
||||
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
|
||||
/**
|
||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||
@ -25,6 +25,9 @@ export class ComfyApp {
|
||||
* @type {serialized node object}
|
||||
*/
|
||||
static clipspace = null;
|
||||
static clipspace_invalidate_handler = null;
|
||||
static open_maskeditor = null;
|
||||
static clipspace_return_node = null;
|
||||
|
||||
constructor() {
|
||||
this.ui = new ComfyUI(this);
|
||||
@ -48,6 +51,114 @@ export class ComfyApp {
|
||||
this.shiftDown = false;
|
||||
}
|
||||
|
||||
static isImageNode(node) {
|
||||
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
|
||||
}
|
||||
|
||||
static onClipspaceEditorSave() {
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
|
||||
}
|
||||
}
|
||||
|
||||
static onClipspaceEditorClosed() {
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
}
|
||||
|
||||
static copyToClipspace(node) {
|
||||
var widgets = null;
|
||||
if(node.widgets) {
|
||||
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
var imgs = undefined;
|
||||
var orig_imgs = undefined;
|
||||
if(node.imgs != undefined) {
|
||||
imgs = [];
|
||||
orig_imgs = [];
|
||||
|
||||
for (let i = 0; i < node.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = node.imgs[i].src;
|
||||
orig_imgs[i] = imgs[i];
|
||||
}
|
||||
}
|
||||
|
||||
var selectedIndex = 0;
|
||||
if(node.imageIndex) {
|
||||
selectedIndex = node.imageIndex;
|
||||
}
|
||||
|
||||
ComfyApp.clipspace = {
|
||||
'widgets': widgets,
|
||||
'imgs': imgs,
|
||||
'original_imgs': orig_imgs,
|
||||
'images': node.images,
|
||||
'selectedIndex': selectedIndex,
|
||||
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
||||
};
|
||||
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
|
||||
if(ComfyApp.clipspace_invalidate_handler) {
|
||||
ComfyApp.clipspace_invalidate_handler();
|
||||
}
|
||||
}
|
||||
|
||||
static pasteFromClipspace(node) {
|
||||
if(ComfyApp.clipspace) {
|
||||
// image paste
|
||||
if(ComfyApp.clipspace.imgs && node.imgs) {
|
||||
if(node.images && ComfyApp.clipspace.images) {
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
||||
}
|
||||
else
|
||||
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
// deep-copy to cut link with clipspace
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
const img = new Image();
|
||||
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
node.imgs = [img];
|
||||
node.imageIndex = 0;
|
||||
}
|
||||
else {
|
||||
const imgs = [];
|
||||
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
||||
node.imgs = imgs;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(node.widgets) {
|
||||
if(ComfyApp.clipspace.images) {
|
||||
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index >= 0) {
|
||||
node.widgets[index].value = clip_image;
|
||||
}
|
||||
}
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop && prop.type != 'button') {
|
||||
prop.value = value;
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke an extension callback
|
||||
* @param {keyof ComfyExtension} method The extension callback to execute
|
||||
@ -137,81 +248,30 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
options.push(
|
||||
{
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => {
|
||||
var widgets = null;
|
||||
if(this.widgets) {
|
||||
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
let img = new Image();
|
||||
var imgs = undefined;
|
||||
if(this.imgs != undefined) {
|
||||
img.src = this.imgs[0].src;
|
||||
imgs = [img];
|
||||
}
|
||||
// prevent conflict of clipspace content
|
||||
if(!ComfyApp.clipspace_return_node) {
|
||||
options.push({
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||
});
|
||||
|
||||
ComfyApp.clipspace = {
|
||||
'widgets': widgets,
|
||||
'imgs': imgs,
|
||||
'original_imgs': imgs,
|
||||
'images': this.images
|
||||
};
|
||||
}
|
||||
});
|
||||
if(ComfyApp.clipspace != null) {
|
||||
options.push({
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||
});
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace != null) {
|
||||
options.push(
|
||||
{
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => {
|
||||
if(ComfyApp.clipspace != null) {
|
||||
if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop) {
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// image paste
|
||||
if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
|
||||
var filename = "";
|
||||
if(this.images && ComfyApp.clipspace.images) {
|
||||
this.images = ComfyApp.clipspace.images;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace.images != undefined) {
|
||||
const clip_image = ComfyApp.clipspace.images[0];
|
||||
if(clip_image.subfolder != '')
|
||||
filename = `${clip_image.subfolder}/`;
|
||||
filename += `${clip_image.filename} [${clip_image.type}]`;
|
||||
}
|
||||
else if(ComfyApp.clipspace.widgets != undefined) {
|
||||
const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index_in_clip >= 0) {
|
||||
filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
|
||||
}
|
||||
}
|
||||
|
||||
const index = this.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
|
||||
this.imgs = ComfyApp.clipspace.imgs;
|
||||
|
||||
this.widgets[index].value = filename;
|
||||
if(this.widgets_values != undefined) {
|
||||
this.widgets_values[index] = filename;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.trigger('changed');
|
||||
if(ComfyApp.isImageNode(this)) {
|
||||
options.push({
|
||||
content: "Open in MaskEditor",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
ComfyApp.clipspace_return_node = this;
|
||||
ComfyApp.open_maskeditor();
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -711,16 +771,27 @@ export class ComfyApp {
|
||||
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
|
||||
const res = origDrawNodeShape.apply(this, arguments);
|
||||
|
||||
const nodeErrors = self.lastPromptError?.node_errors[node.id];
|
||||
|
||||
let color = null;
|
||||
let lineWidth = 1;
|
||||
if (node.id === +self.runningNodeId) {
|
||||
color = "#0f0";
|
||||
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
|
||||
color = "dodgerblue";
|
||||
}
|
||||
else if (self.lastPromptError != null && nodeErrors?.errors) {
|
||||
color = "red";
|
||||
lineWidth = 2;
|
||||
}
|
||||
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
|
||||
color = "#f0f";
|
||||
lineWidth = 2;
|
||||
}
|
||||
|
||||
if (color) {
|
||||
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
|
||||
ctx.lineWidth = 1;
|
||||
ctx.lineWidth = lineWidth;
|
||||
ctx.globalAlpha = 0.8;
|
||||
ctx.beginPath();
|
||||
if (shape == LiteGraph.BOX_SHAPE)
|
||||
@ -747,11 +818,28 @@ export class ComfyApp {
|
||||
ctx.stroke();
|
||||
ctx.strokeStyle = fgcolor;
|
||||
ctx.globalAlpha = 1;
|
||||
}
|
||||
|
||||
if (self.progress) {
|
||||
ctx.fillStyle = "green";
|
||||
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
|
||||
ctx.fillStyle = bgcolor;
|
||||
if (self.progress && node.id === +self.runningNodeId) {
|
||||
ctx.fillStyle = "green";
|
||||
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
|
||||
ctx.fillStyle = bgcolor;
|
||||
}
|
||||
|
||||
// Highlight inputs that failed validation
|
||||
if (nodeErrors) {
|
||||
ctx.lineWidth = 2;
|
||||
ctx.strokeStyle = "red";
|
||||
for (const error of nodeErrors.errors) {
|
||||
if (error.extra_info && error.extra_info.input_name) {
|
||||
const inputIndex = node.findInputSlot(error.extra_info.input_name)
|
||||
if (inputIndex !== -1) {
|
||||
let pos = node.getConnectionPos(true, inputIndex);
|
||||
ctx.beginPath();
|
||||
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
|
||||
ctx.stroke();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -809,6 +897,17 @@ export class ComfyApp {
|
||||
}
|
||||
});
|
||||
|
||||
api.addEventListener("execution_start", ({ detail }) => {
|
||||
this.lastExecutionError = null
|
||||
});
|
||||
|
||||
api.addEventListener("execution_error", ({ detail }) => {
|
||||
this.lastExecutionError = detail;
|
||||
const formattedError = this.#formatExecutionError(detail);
|
||||
this.ui.dialog.show(formattedError);
|
||||
this.canvas.draw(true, true);
|
||||
});
|
||||
|
||||
api.init();
|
||||
}
|
||||
|
||||
@ -842,7 +941,9 @@ export class ComfyApp {
|
||||
await this.#loadExtensions();
|
||||
|
||||
// Create and mount the LiteGraph in the DOM
|
||||
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
|
||||
const mainCanvas = document.createElement("canvas")
|
||||
mainCanvas.style.touchAction = "none"
|
||||
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
|
||||
canvasEl.tabIndex = "1";
|
||||
document.body.prepend(canvasEl);
|
||||
|
||||
@ -909,6 +1010,11 @@ export class ComfyApp {
|
||||
const app = this;
|
||||
// Load node definitions from the backend
|
||||
const defs = await api.getNodeDefs();
|
||||
await this.registerNodesFromDefs(defs);
|
||||
await this.#invokeExtensionsAsync("registerCustomNodes");
|
||||
}
|
||||
|
||||
async registerNodesFromDefs(defs) {
|
||||
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
|
||||
|
||||
// Generate list of known widgets
|
||||
@ -954,7 +1060,8 @@ export class ComfyApp {
|
||||
for (const o in nodeData["output"]) {
|
||||
const output = nodeData["output"][o];
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
this.addOutput(outputName, output);
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
@ -980,8 +1087,6 @@ export class ComfyApp {
|
||||
LiteGraph.registerNodeType(nodeId, node);
|
||||
node.category = nodeData.category;
|
||||
}
|
||||
|
||||
await this.#invokeExtensionsAsync("registerCustomNodes");
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1180,6 +1285,43 @@ export class ComfyApp {
|
||||
return { workflow, output };
|
||||
}
|
||||
|
||||
#formatPromptError(error) {
|
||||
if (error == null) {
|
||||
return "(unknown error)"
|
||||
}
|
||||
else if (typeof error === "string") {
|
||||
return error;
|
||||
}
|
||||
else if (error.stack && error.message) {
|
||||
return error.toString()
|
||||
}
|
||||
else if (error.response) {
|
||||
let message = error.response.error.message;
|
||||
if (error.response.error.details)
|
||||
message += ": " + error.response.error.details;
|
||||
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
|
||||
message += "\n" + nodeError.class_type + ":"
|
||||
for (const errorReason of nodeError.errors) {
|
||||
message += "\n - " + errorReason.message + ": " + errorReason.details
|
||||
}
|
||||
}
|
||||
return message
|
||||
}
|
||||
return "(unknown error)"
|
||||
}
|
||||
|
||||
#formatExecutionError(error) {
|
||||
if (error == null) {
|
||||
return "(unknown error)"
|
||||
}
|
||||
|
||||
const traceback = error.traceback.join("")
|
||||
const nodeId = error.node_id
|
||||
const nodeType = error.node_type
|
||||
|
||||
return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
|
||||
}
|
||||
|
||||
async queuePrompt(number, batchCount = 1) {
|
||||
this.#queueItems.push({ number, batchCount });
|
||||
|
||||
@ -1187,8 +1329,10 @@ export class ComfyApp {
|
||||
if (this.#processingQueue) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
this.#processingQueue = true;
|
||||
this.lastPromptError = null;
|
||||
|
||||
try {
|
||||
while (this.#queueItems.length) {
|
||||
({ number, batchCount } = this.#queueItems.pop());
|
||||
@ -1199,7 +1343,12 @@ export class ComfyApp {
|
||||
try {
|
||||
await api.queuePrompt(number, p);
|
||||
} catch (error) {
|
||||
this.ui.dialog.show(error.response || error.toString());
|
||||
const formattedError = this.#formatPromptError(error)
|
||||
this.ui.dialog.show(formattedError);
|
||||
if (error.response) {
|
||||
this.lastPromptError = error.response;
|
||||
this.canvas.draw(true, true);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1245,6 +1394,11 @@ export class ComfyApp {
|
||||
this.loadGraphData(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1273,14 +1427,19 @@ export class ComfyApp {
|
||||
|
||||
const def = defs[node.type];
|
||||
|
||||
// HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
|
||||
// and additional work is needed to consider the primitive logic in the refresh logic.
|
||||
if(!def)
|
||||
continue;
|
||||
|
||||
for(const widgetNum in node.widgets) {
|
||||
const widget = node.widgets[widgetNum]
|
||||
|
||||
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
|
||||
widget.options.values = def["input"]["required"][widget.name][0];
|
||||
|
||||
if(!widget.options.values.includes(widget.value)) {
|
||||
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
|
||||
widget.value = widget.options.values[0];
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1292,6 +1451,8 @@ export class ComfyApp {
|
||||
*/
|
||||
clean() {
|
||||
this.nodeOutputs = {};
|
||||
this.lastPromptError = null;
|
||||
this.lastExecutionError = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -47,12 +47,29 @@ export function getPngMetadata(file) {
|
||||
});
|
||||
}
|
||||
|
||||
export function getLatentMetadata(file) {
|
||||
return new Promise((r) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
const safetensorsData = new Uint8Array(event.target.result);
|
||||
const dataView = new DataView(safetensorsData.buffer);
|
||||
let header_size = dataView.getUint32(0, true);
|
||||
let offset = 8;
|
||||
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||
r(header.__metadata__);
|
||||
};
|
||||
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
export async function importA1111(graph, parameters) {
|
||||
const p = parameters.lastIndexOf("\nSteps:");
|
||||
if (p > -1) {
|
||||
const embeddings = await api.getEmbeddings();
|
||||
const opts = parameters
|
||||
.substr(p)
|
||||
.split("\n")[1]
|
||||
.split(",")
|
||||
.reduce((p, n) => {
|
||||
const s = n.split(":");
|
||||
|
||||
@ -465,7 +465,7 @@ export class ComfyUI {
|
||||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png",
|
||||
accept: ".json,image/png,.latent",
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
@ -581,6 +581,7 @@ export class ComfyUI {
|
||||
}),
|
||||
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
||||
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
||||
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
|
||||
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||
app.clean();
|
||||
|
||||
@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
|
||||
var v = valueControl.value;
|
||||
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||
let current_length = targetWidget.options.values.length;
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
switch (v) {
|
||||
case "increment":
|
||||
current_index += 1;
|
||||
break;
|
||||
case "decrement":
|
||||
current_index -= 1;
|
||||
break;
|
||||
case "randomize":
|
||||
current_index = Math.floor(Math.random() * current_length);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
current_index = Math.max(0, current_index);
|
||||
current_index = Math.min(current_length - 1, current_index);
|
||||
if (current_index >= 0) {
|
||||
let value = targetWidget.options.values[current_index];
|
||||
targetWidget.value = value;
|
||||
targetWidget.callback(value);
|
||||
}
|
||||
} else { //number
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
return valueControl;
|
||||
};
|
||||
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
|
||||
computeSize(node.size);
|
||||
}
|
||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||
const t = ctx.getTransform();
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin, margin + y);
|
||||
|
||||
Object.assign(this.inputEl.style, {
|
||||
left: `${t.a * margin + t.e}px`,
|
||||
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
||||
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
||||
background: (!node.color)?'':node.color,
|
||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
||||
transformOrigin: "0 0",
|
||||
transform: transform,
|
||||
left: "0px",
|
||||
top: "0px",
|
||||
width: `${widgetWidth - (margin * 2)}px`,
|
||||
height: `${this.parent.inputHeight - (margin * 2)}px`,
|
||||
position: "absolute",
|
||||
background: (!node.color)?'':node.color,
|
||||
color: (!node.color)?'':'white',
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
fontSize: `${t.d * 10.0}px`,
|
||||
});
|
||||
this.inputEl.hidden = !visible;
|
||||
},
|
||||
@ -266,10 +297,46 @@ export const ComfyWidgets = {
|
||||
node.imgs = [img];
|
||||
app.graph.setDirtyCanvas(true);
|
||||
};
|
||||
img.src = `/view?filename=${name}&type=input`;
|
||||
let folder_separator = name.lastIndexOf("/");
|
||||
let subfolder = "";
|
||||
if (folder_separator > -1) {
|
||||
subfolder = name.substring(0, folder_separator);
|
||||
name = name.substring(folder_separator + 1);
|
||||
}
|
||||
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
|
||||
node.setSizeForImage?.();
|
||||
}
|
||||
|
||||
var default_value = imageWidget.value;
|
||||
Object.defineProperty(imageWidget, "value", {
|
||||
set : function(value) {
|
||||
this._real_value = value;
|
||||
},
|
||||
|
||||
get : function() {
|
||||
let value = "";
|
||||
if (this._real_value) {
|
||||
value = this._real_value;
|
||||
} else {
|
||||
return default_value;
|
||||
}
|
||||
|
||||
if (value.filename) {
|
||||
let real_value = value;
|
||||
value = "";
|
||||
if (real_value.subfolder) {
|
||||
value = real_value.subfolder + "/";
|
||||
}
|
||||
|
||||
value += real_value.filename;
|
||||
|
||||
if(real_value.type && real_value.type !== "input")
|
||||
value += ` [${real_value.type}]`;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
});
|
||||
|
||||
// Add our own callback to the combo widget to render an image when it changes
|
||||
const cb = node.callback;
|
||||
imageWidget.callback = function () {
|
||||
|
||||
@ -39,6 +39,8 @@ body {
|
||||
padding: 2px;
|
||||
resize: none;
|
||||
border: none;
|
||||
box-sizing: border-box;
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.comfy-modal {
|
||||
@ -287,6 +289,11 @@ button.comfy-queue-btn {
|
||||
|
||||
/* Context menu */
|
||||
|
||||
.litegraph .dialog {
|
||||
z-index: 1;
|
||||
font-family: Arial;
|
||||
}
|
||||
|
||||
.litegraph .litemenu-entry.has_submenu {
|
||||
position: relative;
|
||||
padding-right: 20px;
|
||||
@ -329,6 +336,7 @@ button.comfy-queue-btn {
|
||||
z-index: 9999 !important;
|
||||
background-color: var(--comfy-menu-bg) !important;
|
||||
overflow: hidden;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.litegraph.litesearchbox input,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user