Merge branch 'comfyanonymous:master' into feature/blockweights

Dr.Lt.Data 2023-06-04 00:42:47 +09:00 committed by GitHub
commit 23332731bd
48 changed files with 4498 additions and 673 deletions

View File

@@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'):
else:
raise AssertionError('Unknown merge analysis result')
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:

View File

@@ -30,6 +30,7 @@ jobs:
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false
- shell: bash
run: |
cd ..

View File

@@ -17,6 +17,7 @@ jobs:
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: '3.11.3'

View File

@@ -1,14 +1,5 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if they exist, otherwise fall back to the PyTorch .bin files
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae

comfy/diffusers_load.py Normal file
View File

@@ -0,0 +1,111 @@
import json
import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file
import diffusers_convert
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))
# magic
v2 = diffusers_unet_conf["sample_size"] == 96
if 'prediction_type' in diffusers_scheduler_conf:
v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'
if v2:
if v_pred:
config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
else:
config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
vae_config['scale_factor'] = scale_factor
model_config_params["unet_config"]["params"]["use_fp16"] = fp16
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
# Load models from safetensors if they exist, otherwise fall back to the PyTorch .bin files
if osp.exists(unet_path):
unet_state_dict = load_file(unet_path, device="cpu")
else:
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = load_file(vae_path, device="cpu")
else:
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = load_file(text_enc_path, device="cpu")
else:
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
# Convert the UNet model
unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
if is_v20_model:
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict)
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
else:
text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
# Put together new checkpoint
sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
clip = None
vae = None
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
load_state_dict_to = []
if output_vae:
vae = VAE(scale_factor=scale_factor, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_state_dict_to = [w]
if output_clip:
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config["model"])
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return ModelPatcher(model), clip, vae
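
For orientation, a minimal usage sketch of the relocated loader (the model directory path is a hypothetical local diffusers checkout, not something this commit ships):

import comfy.diffusers_load
model_patcher, clip, vae = comfy.diffusers_load.load_diffusers(
    "models/diffusers/stable-diffusion-v1-5",  # hypothetical path
    fp16=True, output_vae=True, output_clip=True)

The returned triple mirrors the single-file checkpoint loaders in comfy.sd, so callers can treat a diffusers directory like any other checkpoint.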

View File

@@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x
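
As a quick check of the step bookkeeping above, a standalone sketch with made-up sigma values and eta = 1:

import torch
sigmas = torch.tensor([10.0, 5.0, 1.0, 0.0])
for i in range(len(sigmas) - 1):
    if sigmas[i + 1] == 0:
        continue  # last step just returns the denoised prediction
    t, s = -sigmas[i].log(), -sigmas[i + 1].log()
    h = s - t        # step size in negative log-sigma
    eta_h = 1.0 * h  # eta scales how much fresh noise is injected
    noise_std = (-2 * eta_h).expm1().neg().sqrt()  # factor on the Brownian noise, as in the last line of the loop
    print(float(h), float(noise_std))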

View File

@@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
return x+h
def slice_attention(q, k, v):
r1 = torch.zeros_like(k, device=q.device)
scale = (int(q.shape[-1])**(-0.5))
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
return r1
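
The slice count doubles until the projected attention matrix fits in free memory; a small sketch of that arithmetic with made-up numbers:

import math
tensor_size = 2 * 4096 * 4096 * 2   # q.shape[0] * q.shape[1] * k.shape[2] * element_size (fp16), hypothetical
mem_required = tensor_size * 3      # modifier 3 for fp16, as above
mem_free_total = 48 * 1024 ** 2     # pretend only 48 MB are free
steps = 1
if mem_required > mem_free_total:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)  # 4: the query dimension is processed in 4 slices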
class AttnBlock(nn.Module):
def __init__(self, in_channels):
@@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
# compute attention
b,c,h,w = q.shape
scale = (int(c)**(-0.5))
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = torch.zeros_like(k, device=q.device)
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_)
return x+h_
@@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
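
The new path keeps q, k, v in a (B, 1, H*W, C) layout for scaled_dot_product_attention and only reshapes back on success; a shape check of that round trip, with toy dimensions:

import torch
B, C, H, W = 1, 8, 4, 4
q = torch.randn(B, 1, C, H * W).transpose(2, 3)  # b, 1, hw, c, as built above
out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
print(out.transpose(2, 3).reshape(B, C, H, W).shape)  # torch.Size([1, 8, 4, 4])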

View File

@@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
"""
B, N, _ = metric.shape
if r <= 0:
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

View File

@@ -1,23 +1,29 @@
import psutil
from enum import Enum
from comfy.cli_args import args
import torch
class VRAMState(Enum):
CPU = 0
DISABLED = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
SHARED = 5
class CPUState(Enum):
GPU = 0
CPU = 1
MPS = 2
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU
total_vram = 0
total_vram_available_mb = -1
accelerate_enabled = False
lowvram_available = True
xpu_available = False
directml_enabled = False
@@ -31,30 +37,80 @@ if args.directml is not None:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import torch
if directml_enabled:
total_vram = 4097 #TODO
else:
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
xpu_available = True
except:
pass
try:
if torch.backends.mps.is_available():
cpu_state = CPUState.MPS
except:
pass
if args.cpu:
cpu_state = CPUState.CPU
def get_torch_device():
global xpu_available
global directml_enabled
global cpu_state
if directml_enabled:
global directml_device
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
if cpu_state == CPUState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.device(torch.cuda.current_device())
def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_cuda
if torch_total_too:
return (mem_total, mem_total_torch)
else:
return mem_total
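
A quick way to see both numbers with the function above on a CUDA machine (values depend entirely on the hardware):

total, torch_total = get_total_memory(torch_total_too=True)
print("reported: {:0.0f} MB, torch reserved: {:0.0f} MB".format(total / (1024 * 1024), torch_total / (1024 * 1024)))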
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = VRAMState.HIGH_VRAM
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
@@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION:
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
lowvram_available = True
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
@@ -102,32 +159,42 @@ if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
if lowvram_available:
try:
import accelerate
accelerate_enabled = True
vram_state = set_vram_to
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb))
try:
if torch.backends.mps.is_available():
vram_state = VRAMState.MPS
except:
pass
if cpu_state != CPUState.GPU:
vram_state = VRAMState.DISABLED
if args.cpu:
vram_state = VRAMState.CPU
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED
print(f"Set vram state to: {vram_state.name}")
def get_torch_device_name(device):
if hasattr(device, 'type'):
if device.type == "cuda":
return "{} {}".format(device, torch.cuda.get_device_name(device))
else:
return "{}".format(device.type)
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
current_loaded_model = None
current_gpu_controlnets = []
@@ -173,22 +240,29 @@ def load_model_gpu(model):
model.unpatch_model()
raise e
model.model_patches_to(get_torch_device())
torch_dev = get_torch_device()
model.model_patches_to(torch_dev)
vram_set_state = vram_state
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = model.model_size()
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
current_loaded_model = model
if vram_state == VRAMState.CPU:
if vram_set_state == VRAMState.DISABLED:
pass
elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(get_torch_device())
else:
if vram_state == VRAMState.NO_VRAM:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
model_accelerated = True
@@ -197,7 +271,7 @@ def load_model_gpu(model):
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.CPU:
if vram_state == VRAMState.DISABLED:
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@@ -233,22 +307,6 @@ def unload_if_low_vram(model):
return model.cpu()
return model
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
@@ -258,7 +316,8 @@ def get_autocast_device(dev):
def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU:
global cpu_state
if cpu_state != CPUState.GPU:
return False
if xpu_available:
return False
@@ -330,12 +389,12 @@ def maximum_batch_area():
return int(max(area, 0))
def cpu_mode():
global vram_state
return vram_state == VRAMState.CPU
global cpu_state
return cpu_state == CPUState.CPU
def mps_mode():
global vram_state
return vram_state == VRAMState.MPS
global cpu_state
return cpu_state == CPUState.MPS
def should_use_fp16():
global xpu_available
@@ -367,7 +426,10 @@ def should_use_fp16():
def soft_empty_cache():
global xpu_available
if xpu_available:
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif xpu_available:
torch.xpu.empty_cache()
elif torch.cuda.is_available():
if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda

View File

@@ -2,17 +2,26 @@ import torch
import comfy.model_management
import comfy.samplers
import math
import numpy as np
def prepare_noise(latent_image, seed, skip=0):
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
for _ in range(skip):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
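
A hedged example of the new batch-index behaviour, using the function above (shapes are illustrative):

import numpy as np
import torch
latent = torch.zeros(3, 4, 8, 8)
noise = prepare_noise(latent, seed=42, noise_inds=np.array([2, 0, 2]))
print(noise.shape)  # torch.Size([3, 4, 8, 8]); rows 0 and 2 carry identical noise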
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""

View File

@@ -6,6 +6,10 @@ import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
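
For instance, with the helper above, on the CLIP token-length multiples this code has to reconcile:

print(lcm(77, 154), lcm(77, 308))  # 154 308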
#The main sampling function shared by all the samplers
#Returns predicted noise
@@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
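
The padding trick this compatibility check enables: a shorter c_crossattn can be repeated along the token axis to match a longer one without changing cross-attention results. A small illustration (the 77/154 token lengths are the usual CLIP sizes, assumed here):

import torch
c1 = torch.randn(1, 77, 768)     # standard prompt length
c2 = torch.randn(1, 154, 768)    # prompt built from two concatenated chunks
c1 = c1.repeat(1, 154 // 77, 1)  # lcm(77, 154) == 154; repeating keeps attention output unchanged
print(torch.cat([c1, c2]).shape)  # torch.Size([2, 154, 768])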
@@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
@@ -362,19 +386,8 @@ def resolve_cond_masks(conditions, h, w, device):
else:
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
# Make sure the height and width are divisible by 8
if X % 8 != 0:
newx = X // 8 * 8
W = W + (X - newx)
X = newx
if Y % 8 != 0:
newy = Y // 8 * 8
H = H + (Y - newy)
Y = newy
if H % 8 != 0:
H = H + (8 - (H % 8))
if W % 8 != 0:
W = W + (8 - (W % 8))
H = max(8, H)
W = max(8, W)
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
@@ -482,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
@@ -519,6 +532,8 @@ class KSampler:
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":
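
The new "exponential" schedule spaces sigmas evenly in log-space; a minimal sketch of the idea (this may differ in detail from k_diffusion's get_sigmas_exponential, which this commit actually calls):

import math
import torch
def sigmas_exponential_sketch(n, sigma_min, sigma_max):
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return torch.cat([sigmas, sigmas.new_zeros([1])])  # trailing zero ends the sampling loop
print(sigmas_exponential_sketch(5, 0.03, 14.6))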

View File

@@ -14,6 +14,7 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
from . import diffusers_convert
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
@@ -324,15 +325,29 @@ def model_lora_keys(model, key_map={}):
return key_map
class ModelPatcher:
def __init__(self, model):
def __init__(self, model, size=0):
self.size = size
self.model = model
self.patches = []
self.backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
return size
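
model_size is just the byte count of every tensor in the state dict; the same arithmetic on a toy module:

import torch
sd = torch.nn.Linear(4, 4).state_dict()
print(sum(t.nelement() * t.element_size() for t in sd.values()))  # 80 bytes: (16 weights + 4 biases) * 4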
def clone(self):
n = ModelPatcher(self.model)
n = ModelPatcher(self.model, self.size)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
return n
@@ -553,10 +568,16 @@ class VAE:
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path)
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
else:
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
self.scale_factor = scale_factor
if device is None:
device = model_management.get_torch_device()
@@ -630,12 +651,9 @@ class VAE:
samples = samples.cpu()
return samples
def resize_image_to(tensor, target_latent_tensor, batched_number):
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
target_batch_size = target_latent_tensor.shape[0]
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
print(current_batch_size, target_batch_size)
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@@ -652,7 +670,7 @@ def resize_image_to(tensor, target_latent_tensor, batched_number):
return torch.cat([tensor] * batched_number, dim=0)
class ControlNet:
def __init__(self, control_model, device=None):
def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
@@ -661,6 +679,7 @@ class ControlNet:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = global_average_pooling
def get_control(self, x_noisy, t, cond_txt, batched_number):
control_prev = None
@@ -672,7 +691,9 @@ class ControlNet:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_model.dtype == torch.float16:
precision_scope = torch.autocast
@@ -694,6 +715,9 @@ class ControlNet:
key = 'output'
index = i
x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)
@@ -724,7 +748,7 @@ class ControlNet:
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model)
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
return c
@@ -772,7 +796,7 @@ def load_controlnet(ckpt_path, model=None):
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
else:
@@ -789,7 +813,7 @@ def load_controlnet(ckpt_path, model=None):
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
use_checkpoint=False,
legacy=False,
use_fp16=use_fp16)
if pth:
@@ -819,7 +843,11 @@ def load_controlnet(ckpt_path, model=None):
if use_fp16:
control_model = control_model.half()
control = ControlNet(control_model)
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
return control
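
What global_average_pooling does to a control output, in isolation (tensor size is hypothetical):

import torch
x = torch.randn(1, 320, 8, 8)
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, 8, 8)  # every spatial position becomes the mean
print(x.shape, torch.allclose(x[..., 0, 0], x[..., 7, 7]))      # torch.Size([1, 320, 8, 8]) True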
class T2IAdapter:
@@ -843,10 +871,14 @@ class T2IAdapter:
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.control_input = None
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint)
self.t2i_model.cpu()
@@ -1070,7 +1102,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
}
unet_config = {
"use_checkpoint": True,
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"attention_resolutions": [

View File

@@ -56,7 +56,12 @@ class Downsample(nn.Module):
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
if not self.use_conv:
padding = [x.shape[2] % 2, x.shape[3] % 2]
self.op.padding = padding
x = self.op(x)
return x
class ResnetBlock(nn.Module):

View File

@@ -1,11 +1,16 @@
import torch
import math
import struct
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
@@ -46,6 +51,88 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd
def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
length_of_header = struct.unpack('<Q', header)[0]
if length_of_header > max_size:
return None
return f.read(length_of_header)
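
Using the helper above to peek at a file's tensor names without loading any weights (the file name is a placeholder):

import json
header = safetensors_header("model.safetensors")  # hypothetical path
if header is not None:
    print(list(json.loads(header).keys())[:5])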
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
#norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
#normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
#zero when norms are zero
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
#slerp
dot = (b1_normalized*b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
#technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c))
pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c))
ratios = ratios.movedim(1, -1).reshape((-1,1))
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
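
A hedged usage example for the latent upscaler defined above:

import torch
latent = torch.randn(1, 4, 64, 64)
print(bislerp(latent, 128, 128).shape)  # torch.Size([1, 4, 128, 128])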
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
@@ -61,7 +148,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if upscale_method == "bislerp":
return bislerp(s, width, height)
else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))

View File

@@ -0,0 +1,110 @@
import math
import torch.nn as nn
class CA_layer(nn.Module):
def __init__(self, channel, reduction=16):
super(CA_layer, self).__init__()
# global average pooling
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
nn.GELU(),
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
# nn.Sigmoid()
)
def forward(self, x):
y = self.fc(self.gap(x))
return x * y.expand_as(x)
class Simple_CA_layer(nn.Module):
def __init__(self, channel):
super(Simple_CA_layer, self).__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(
in_channels=channel,
out_channels=channel,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)
def forward(self, x):
return x * self.fc(self.gap(x))
class ECA_layer(nn.Module):
"""Constructs a ECA module.
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel):
super(ECA_layer, self).__init__()
b = 1
gamma = 2
k_size = int(abs(math.log(channel, 2) + b) / gamma)
k_size = k_size if k_size % 2 else k_size + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
)
# self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: input features with shape [b, c, h, w]
# b, c, h, w = x.size()
# feature descriptor on the global spatial information
y = self.avg_pool(x)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
# y = self.sigmoid(y)
return x * y.expand_as(x)
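
The adaptive kernel size above works out small for typical channel counts; for example:

import math
for channel in (64, 256):
    k_size = int(abs(math.log(channel, 2) + 1) / 2)  # b = 1, gamma = 2
    print(k_size if k_size % 2 else k_size + 1)      # 3, then 5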
class ECA_MaxPool_layer(nn.Module):
"""Constructs a ECA module.
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel):
super(ECA_MaxPool_layer, self).__init__()
b = 1
gamma = 2
k_size = int(abs(math.log(channel, 2) + b) / gamma)
k_size = k_size if k_size % 2 else k_size + 1
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.conv = nn.Conv1d(
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
)
# self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: input features with shape [b, c, h, w]
# b, c, h, w = x.size()
# feature descriptor on the global spatial information
y = self.max_pool(x)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
# y = self.sigmoid(y)
return x * y.expand_as(x)

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSA.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn
from .layernorm import LayerNorm2d
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length=1):
return val if isinstance(val, tuple) else ((val,) * length)
# helper classes
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
class Conv_PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm2d(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x)) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult=2, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Conv_FeedForward(nn.Module):
def __init__(self, dim, mult=2, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1, 0),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1, 1, 0),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Gated_Conv_FeedForward(nn.Module):
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
super().__init__()
hidden_features = int(dim * mult)
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(
hidden_features * 2,
hidden_features * 2,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features * 2,
bias=bias,
)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
# MBConv
class SqueezeExcitation(nn.Module):
def __init__(self, dim, shrinkage_rate=0.25):
super().__init__()
hidden_dim = int(dim * shrinkage_rate)
self.gate = nn.Sequential(
Reduce("b c h w -> b c", "mean"),
nn.Linear(dim, hidden_dim, bias=False),
nn.SiLU(),
nn.Linear(hidden_dim, dim, bias=False),
nn.Sigmoid(),
Rearrange("b c -> b c 1 1"),
)
def forward(self, x):
return x * self.gate(x)
class MBConvResidual(nn.Module):
def __init__(self, fn, dropout=0.0):
super().__init__()
self.fn = fn
self.dropsample = Dropsample(dropout)
def forward(self, x):
out = self.fn(x)
out = self.dropsample(out)
return out + x
class Dropsample(nn.Module):
def __init__(self, prob=0):
super().__init__()
self.prob = prob
def forward(self, x):
device = x.device
if self.prob == 0.0 or (not self.training):
return x
keep_mask = (
torch.empty((x.shape[0], 1, 1, 1), device=device).uniform_()  # fixed: torch.FloatTensor((x.shape[0], 1, 1, 1)) would treat the tuple as data, not a shape
> self.prob
)
return x * keep_mask / (1 - self.prob)
def MBConv(
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
hidden_dim = int(expansion_rate * dim_out)
stride = 2 if downsample else 1
net = nn.Sequential(
nn.Conv2d(dim_in, hidden_dim, 1),
# nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
),
# nn.BatchNorm2d(hidden_dim),
nn.GELU(),
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
# nn.BatchNorm2d(dim_out)
)
if dim_in == dim_out and not downsample:
net = MBConvResidual(net, dropout=dropout)
return net
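
A shape check of the block factory above (it depends on the SqueezeExcitation, MBConvResidual, and Dropsample classes defined earlier):

import torch
block = MBConv(64, 64, downsample=False)
x = torch.randn(1, 64, 16, 16)
print(block(x).shape)  # torch.Size([1, 64, 16, 16]); residual path is active since dim_in == dim_out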
# attention related classes
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=32,
dropout=0.0,
window_size=7,
with_pe=True,
):
super().__init__()
assert (
dim % dim_head
) == 0, "dimension should be divisible by dimension per head"
self.heads = dim // dim_head
self.scale = dim_head**-0.5
self.with_pe = with_pe
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
self.to_out = nn.Sequential(
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
)
# relative positional bias
if self.with_pe:
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos))
grid = rearrange(grid, "c i j -> (i j) c")
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
grid, "j ... -> 1 j ..."
)
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
dim=-1
)
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = (
*x.shape,
x.device,
self.heads,
)
# flatten
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
# project for queries, keys, values
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# split heads
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
# scale
q = q * self.scale
# sim
sim = einsum("b h i d, b h j d -> b h i j", q, k)
# add positional bias
if self.with_pe:
bias = self.rel_pos_bias(self.rel_pos_indices)
sim = sim + rearrange(bias, "i j h -> h i j")
# attention
attn = self.attend(sim)
# aggregate
out = einsum("b h i j, b h j d -> b h i d", attn, v)
# merge heads
out = rearrange(
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
)
# combine heads out
out = self.to_out(out)
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
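# Illustrative sketch: this attention expects pre-windowed input shaped
# (b, x, y, w1, w2, d) with w1 == w2 == window_size; the relative position
# bias is looked up per head from the precomputed rel_pos_indices.
_wattn = Attention(dim=64, dim_head=16, window_size=8)
assert _wattn(torch.randn(1, 2, 2, 8, 8, 64)).shape == (1, 2, 2, 8, 8, 64)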
class Block_Attention(nn.Module):
def __init__(
self,
dim,
dim_head=32,
bias=False,
dropout=0.0,
window_size=7,
with_pe=True,
):
super().__init__()
assert (
dim % dim_head
) == 0, "dimension should be divisible by dimension per head"
self.heads = dim // dim_head
self.ps = window_size
self.scale = dim_head**-0.5
self.with_pe = with_pe
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
# project for queries, keys, values
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
# split heads
q, k, v = map(
lambda t: rearrange(
t,
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
h=self.heads,
w1=self.ps,
w2=self.ps,
),
(q, k, v),
)
# scale
q = q * self.scale
# sim
sim = einsum("b h i d, b h j d -> b h i j", q, k)
# attention
attn = self.attend(sim)
# aggregate
out = einsum("b h i j, b h j d -> b h i d", attn, v)
# merge heads
out = rearrange(
out,
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
x=h // self.ps,
y=w // self.ps,
head=self.heads,
w1=self.ps,
w2=self.ps,
)
out = self.to_out(out)
return out
class Channel_Attention(nn.Module):
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
super(Channel_Attention, self).__init__()
self.heads = heads
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.ps = window_size
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
qkv = qkv.chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(
t,
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
ph=self.ps,
pw=self.ps,
head=self.heads,
),
qkv,
)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = attn @ v
out = rearrange(
out,
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
h=h // self.ps,
w=w // self.ps,
ph=self.ps,
pw=self.ps,
head=self.heads,
)
out = self.project_out(out)
return out
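# Illustrative sketch: this is transposed (channel-wise) attention computed
# per non-overlapping window, so C must divide evenly by heads and H, W must
# be multiples of window_size.
_ca = Channel_Attention(dim=64, heads=4, window_size=8)
assert _ca(torch.randn(1, 64, 16, 16)).shape == (1, 64, 16, 16)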
class Channel_Attention_grid(nn.Module):
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
super(Channel_Attention_grid, self).__init__()
self.heads = heads
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.ps = window_size
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3,
dim * 3,
kernel_size=3,
stride=1,
padding=1,
groups=dim * 3,
bias=bias,
)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
qkv = qkv.chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(
t,
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
ph=self.ps,
pw=self.ps,
head=self.heads,
),
qkv,
)
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = attn @ v
out = rearrange(
out,
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
h=h // self.ps,
w=w // self.ps,
ph=self.ps,
pw=self.ps,
head=self.heads,
)
out = self.project_out(out)
return out
class OSA_Block(nn.Module):
def __init__(
self,
channel_num=64,
bias=True,
ffn_bias=True,
window_size=8,
with_pe=False,
dropout=0.0,
):
super(OSA_Block, self).__init__()
w = window_size
self.layer = nn.Sequential(
MBConv(
channel_num,
channel_num,
downsample=False,
expansion_rate=1,
shrinkage_rate=0.25,
),
Rearrange(
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
), # block-like attention
PreNormResidual(
channel_num,
Attention(
dim=channel_num,
dim_head=channel_num // 4,
dropout=dropout,
window_size=window_size,
with_pe=with_pe,
),
),
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
# channel-like attention
Conv_PreNormResidual(
channel_num,
Channel_Attention(
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
),
),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
Rearrange(
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
), # grid-like attention
PreNormResidual(
channel_num,
Attention(
dim=channel_num,
dim_head=channel_num // 4,
dropout=dropout,
window_size=window_size,
with_pe=with_pe,
),
),
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
# channel-like attention
Conv_PreNormResidual(
channel_num,
Channel_Attention_grid(
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
),
),
Conv_PreNormResidual(
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
),
)
def forward(self, x):
out = self.layer(x)
return out
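A minimal usage sketch, assuming the module's imports: the block chains MBConv, window (block) attention, channel attention, grid attention, and their gated conv feed-forwards, so spatial dimensions must be multiples of window_size and are preserved.

block = OSA_Block(channel_num=64, window_size=8)
x = torch.randn(1, 64, 32, 32)
assert block(x).shape == (1, 64, 32, 32)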

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSAG.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn
from .esa import ESA
from .OSA import OSA_Block
class OSAG(nn.Module):
def __init__(
self,
channel_num=64,
bias=True,
block_num=4,
ffn_bias=False,
window_size=0,
pe=False,
):
super(OSAG, self).__init__()
# print("window_size: %d" % (window_size))
# print("with_pe", pe)
# print("ffn_bias: %d" % (ffn_bias))
# block_script_name = kwargs.get("block_script_name", "OSA")
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
# script_name = "." + block_script_name
# package = __import__(script_name, fromlist=True)
block_class = OSA_Block # getattr(package, block_class_name)
group_list = []
for _ in range(block_num):
temp_res = block_class(
channel_num,
bias,
ffn_bias=ffn_bias,
window_size=window_size,
with_pe=pe,
)
group_list.append(temp_res)
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
self.residual_layer = nn.Sequential(*group_list)
esa_channel = max(channel_num // 4, 16)
self.esa = ESA(esa_channel, channel_num)
def forward(self, x):
out = self.residual_layer(x)
out = out + x
return self.esa(out)

View File

@ -0,0 +1,133 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OmniSR.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .OSAG import OSAG
from .pixelshuffle import pixelshuffle_block
class OmniSR(nn.Module):
def __init__(
self,
state_dict,
**kwargs,
):
super(OmniSR, self).__init__()
self.state = state_dict
bias = True # Fine to assume this for now
block_num = 1 # Fine to assume this for now
ffn_bias = True
pe = True
num_feat = state_dict["input.weight"].shape[0] or 64
num_in_ch = state_dict["input.weight"].shape[1] or 3
num_out_ch = num_in_ch # assume this for now; the pixelshuffle layer makes out_nc hard to recover directly
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
if up_scale - int(up_scale) > 0:
print(
"out_nc is probably different from in_nc, scale calculation might be wrong"
)
up_scale = int(up_scale)
res_num = 0
for key in state_dict.keys():
if "residual_layer" in key:
temp_res_num = int(key.split(".")[1])
if temp_res_num > res_num:
res_num = temp_res_num
res_num = res_num + 1 # zero-indexed
residual_layer = []
self.res_num = res_num
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
self.up_scale = up_scale
for _ in range(res_num):
temp_res = OSAG(
channel_num=num_feat,
bias=bias,
block_num=block_num,
ffn_bias=ffn_bias,
window_size=self.window_size,
pe=pe,
)
residual_layer.append(temp_res)
self.residual_layer = nn.Sequential(*residual_layer)
self.input = nn.Conv2d(
in_channels=num_in_ch,
out_channels=num_feat,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
)
self.output = nn.Conv2d(
in_channels=num_feat,
out_channels=num_feat,
kernel_size=3,
stride=1,
padding=1,
bias=bias,
)
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, sqrt(2. / n))
# chaiNNer specific stuff
self.model_arch = "OmniSR"
self.sub_type = "SR"
self.in_nc = num_in_ch
self.out_nc = num_out_ch
self.num_feat = num_feat
self.scale = up_scale
self.supports_fp16 = True # TODO: Test this
self.supports_bfp16 = True
self.min_size_restriction = 16
self.load_state_dict(state_dict, strict=False)
def check_image_size(self, x):
_, _, h, w = x.size()
# import pdb; pdb.set_trace()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
residual = self.input(x)
out = self.residual_layer(residual)
# origin
out = torch.add(self.output(out), residual)
out = self.up(out)
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
return out
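The upscale factor is recovered from the pixel-shuffle projection: `up.0.weight` has `num_out_ch * up_scale**2` output channels. A worked example with hypothetical values for a 4x RGB model:

import math

pixelshuffle_shape = 48                 # hypothetical up.0.weight.shape[0]
num_out_ch = 3
up_scale = int(math.sqrt(pixelshuffle_shape / num_out_ch))  # 4, since 3 * 4**2 == 48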

View File

@ -0,0 +1,294 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: esa.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th April 2023 9:28:06 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layernorm import LayerNorm2d
def moment(x, dim=(2, 3), k=2):
assert len(x.size()) == 4
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
return mk
class ESA(nn.Module):
"""
Modification of Enhanced Spatial Attention (ESA), which is proposed by
`Residual Feature Aggregation Network for Image Super-Resolution`
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding code
has been removed.
"""
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
super(ESA, self).__init__()
f = esa_channels
self.conv1 = conv(n_feats, f, kernel_size=1)
self.conv_f = conv(f, f, kernel_size=1)
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
self.conv3 = conv(f, f, kernel_size=3, padding=1)
self.conv4 = conv(f, n_feats, kernel_size=1)
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
c1_ = self.conv1(x)
c1 = self.conv2(c1_)
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
c3 = self.conv3(v_max)
c3 = F.interpolate(
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
)
cf = self.conv_f(c1_)
c4 = self.conv4(c3 + cf)
m = self.sigmoid(c4)
return x * m
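# Illustrative sketch: ESA builds a sigmoid gate from a strided conv plus
# max-pool branch, upsamples it back to input resolution, and re-weights the
# features spatially; shapes are preserved.
_esa = ESA(esa_channels=16, n_feats=64)
assert _esa(torch.randn(1, 64, 32, 32)).shape == (1, 64, 32, 32)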
class LK_ESA(nn.Module):
def __init__(
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
):
super(LK_ESA, self).__init__()
f = esa_channels
self.conv1 = conv(n_feats, f, kernel_size=1)
self.conv_f = conv(f, f, kernel_size=1)
kernel_size = 17
kernel_expand = kernel_expand
padding = kernel_size // 2
self.vec_conv = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(1, kernel_size),
padding=(0, padding),
groups=2,
bias=bias,
)
self.vec_conv3x1 = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(1, 3),
padding=(0, 1),
groups=2,
bias=bias,
)
self.hor_conv = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(kernel_size, 1),
padding=(padding, 0),
groups=2,
bias=bias,
)
self.hor_conv1x3 = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(3, 1),
padding=(1, 0),
groups=2,
bias=bias,
)
self.conv4 = conv(f, n_feats, kernel_size=1)
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
c1_ = self.conv1(x)
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
res = self.hor_conv(res) + self.hor_conv1x3(res)
cf = self.conv_f(c1_)
c4 = self.conv4(res + cf)
m = self.sigmoid(c4)
return x * m
class LK_ESA_LN(nn.Module):
def __init__(
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
):
super(LK_ESA_LN, self).__init__()
f = esa_channels
self.conv1 = conv(n_feats, f, kernel_size=1)
self.conv_f = conv(f, f, kernel_size=1)
kernel_size = 17
kernel_expand = kernel_expand
padding = kernel_size // 2
self.norm = LayerNorm2d(n_feats)
self.vec_conv = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(1, kernel_size),
padding=(0, padding),
groups=2,
bias=bias,
)
self.vec_conv3x1 = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(1, 3),
padding=(0, 1),
groups=2,
bias=bias,
)
self.hor_conv = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(kernel_size, 1),
padding=(padding, 0),
groups=2,
bias=bias,
)
self.hor_conv1x3 = nn.Conv2d(
in_channels=f * kernel_expand,
out_channels=f * kernel_expand,
kernel_size=(3, 1),
padding=(1, 0),
groups=2,
bias=bias,
)
self.conv4 = conv(f, n_feats, kernel_size=1)
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
c1_ = self.norm(x)
c1_ = self.conv1(c1_)
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
res = self.hor_conv(res) + self.hor_conv1x3(res)
cf = self.conv_f(c1_)
c4 = self.conv4(res + cf)
m = self.sigmoid(c4)
return x * m
class AdaGuidedFilter(nn.Module):
def __init__(
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
):
super(AdaGuidedFilter, self).__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(
in_channels=n_feats,
out_channels=1,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)
self.r = 5
def box_filter(self, x, r):
channel = x.shape[1]
kernel_size = 2 * r + 1
weight = 1.0 / (kernel_size**2)
box_kernel = weight * torch.ones(
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
)
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
return output
def forward(self, x):
_, _, H, W = x.shape
N = self.box_filter(
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
)
# epsilon = self.fc(self.gap(x))
# epsilon = torch.pow(epsilon, 2)
epsilon = 1e-2
mean_x = self.box_filter(x, self.r) / N
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
A = var_x / (var_x + epsilon)
b = (1 - A) * mean_x
m = A * x + b
# mean_A = self.box_filter(A, self.r) / N
# mean_b = self.box_filter(b, self.r) / N
# m = mean_A * x + mean_b
return x * m
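# Illustrative sketch: the guided-filter gate m = A * x + b behaves like an
# edge detector -- in flat regions var_x ~ 0 so A ~ 0 and m ~ mean_x, while
# around edges var_x >> epsilon so A ~ 1 and m tracks x.
_agf = AdaGuidedFilter(esa_channels=16, n_feats=64)
assert _agf(torch.randn(1, 64, 32, 32)).shape == (1, 64, 32, 32)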
class AdaConvGuidedFilter(nn.Module):
def __init__(
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
):
super(AdaConvGuidedFilter, self).__init__()
f = esa_channels
self.conv_f = conv(f, f, kernel_size=1)
kernel_size = 17
kernel_expand = kernel_expand
padding = kernel_size // 2
self.vec_conv = nn.Conv2d(
in_channels=f,
out_channels=f,
kernel_size=(1, kernel_size),
padding=(0, padding),
groups=f,
bias=bias,
)
self.hor_conv = nn.Conv2d(
in_channels=f,
out_channels=f,
kernel_size=(kernel_size, 1),
padding=(padding, 0),
groups=f,
bias=bias,
)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(
in_channels=f,
out_channels=f,
kernel_size=1,
padding=0,
stride=1,
groups=1,
bias=True,
)
def forward(self, x):
y = self.vec_conv(x)
y = self.hor_conv(y)
sigma = torch.pow(y, 2)
epsilon = self.fc(self.gap(y))
weight = sigma / (sigma + epsilon)
m = weight * x + (1 - weight)
return x * m

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: layernorm.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 20th April 2023 9:28:20 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps):
ctx.eps = eps
N, C, H, W = x.size()
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
N, C, H, W = grad_output.size()
y, var, weight = ctx.saved_tensors
g = grad_output * weight.view(1, C, 1, 1)
mean_g = g.mean(dim=1, keepdim=True)
mean_gy = (g * y).mean(dim=1, keepdim=True)
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return (
gx,
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
None,
)
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2d, self).__init__()
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
self.eps = eps
def forward(self, x):
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
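# Illustrative sanity check: LayerNorm2d normalizes over the channel axis at
# each spatial position, matching the closed form used in forward() when
# weight is all ones and bias all zeros (the initial state).
_ln = LayerNorm2d(8)
_x = torch.randn(2, 8, 4, 4)
_mu = _x.mean(1, keepdim=True)
_var = (_x - _mu).pow(2).mean(1, keepdim=True)
assert torch.allclose(_ln(_x), (_x - _mu) / (_var + 1e-6).sqrt(), atol=1e-5)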
class GRN(nn.Module):
"""GRN (Global Response Normalization) layer"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x

View File

@ -0,0 +1,31 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: pixelshuffle.py
# Created Date: Friday July 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 1st July 2022 10:18:39 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch.nn as nn
def pixelshuffle_block(
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
):
"""
Upsample features according to `upscale_factor`.
"""
padding = kernel_size // 2
conv = nn.Conv2d(
in_channels,
out_channels * (upscale_factor**2),
kernel_size,
padding=1,
bias=bias,
)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
return nn.Sequential(*[conv, pixel_shuffle])
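A usage sketch: the conv expands channels by upscale_factor**2 and PixelShuffle folds them into space. Note that padding is hardcoded to 1, so spatial size is only preserved for the default kernel_size=3.

up = pixelshuffle_block(64, 3, upscale_factor=4)
assert up(torch.randn(1, 64, 24, 24)).shape == (1, 3, 96, 96)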

View File

@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
self.scale: int = self.get_scale()
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
c2x2 = False
if self.state["model.0.weight"].shape[-2] == 2:
c2x2 = True
self.scale = round(math.sqrt(self.scale / 4))
self.model_arch = "ESRGAN-2c2"
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
out_nc=self.num_filters,
upscale_factor=3,
act_type=self.act,
c2x2=c2x2,
)
else:
upsample_blocks = [
upsample_block(
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
in_nc=self.num_filters,
out_nc=self.num_filters,
act_type=self.act,
c2x2=c2x2,
)
for _ in range(int(math.log(self.scale, 2)))
]
@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
B.ShortcutBlock(
B.sequential(
@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
act_type=self.act,
mode="CNA",
plus=self.plus,
c2x2=c2x2,
)
for _ in range(self.num_blocks)
],
@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
norm_type=self.norm,
act_type=None,
mode=self.mode,
c2x2=c2x2,
),
)
),
@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=self.act,
c2x2=c2x2,
),
# hr_conv1
B.conv_block(
@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
)

View File

@ -141,6 +141,19 @@ def sequential(*args):
ConvMode = Literal["CNA", "NAC", "CNAC"]
# 2x2x2 Conv Block
def conv_block_2c2(
in_nc,
out_nc,
act_type="relu",
):
return sequential(
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
act(act_type) if act_type else None,
)
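# Illustrative sketch: the two 2x2 convs cancel spatially -- kernel_size=2
# with padding=1 grows each spatial dimension by one and the second conv
# with padding=0 shrinks it back, so output H and W match the input.
_c2x2 = conv_block_2c2(3, 16)
assert _c2x2(torch.randn(1, 3, 32, 32)).shape == (1, 16, 32, 32)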
def conv_block(
in_nc: int,
out_nc: int,
@ -153,12 +166,17 @@ def conv_block(
norm_type: str | None = None,
act_type: str | None = "relu",
mode: ConvMode = "CNA",
c2x2=False,
):
"""
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
if c2x2:
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
@ -285,6 +303,7 @@ class RRDB(nn.Module):
_convtype="Conv2D",
_spectral_norm=False,
plus=False,
c2x2=False,
):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(
@ -298,6 +317,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB2 = ResidualDenseBlock_5C(
nf,
@ -310,6 +330,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB3 = ResidualDenseBlock_5C(
nf,
@ -322,6 +343,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
def forward(self, x):
@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
act_type="leakyrelu",
mode: ConvMode = "CNA",
plus=False,
c2x2=False,
):
super(ResidualDenseBlock_5C, self).__init__()
@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv2 = conv_block(
nf + gc,
@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv3 = conv_block(
nf + 2 * gc,
@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv4 = conv_block(
nf + 3 * gc,
@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
if mode == "CNA":
last_act = None
@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=last_act,
mode=mode,
c2x2=c2x2,
)
def forward(self, x):
@ -499,6 +527,7 @@ def upconv_block(
norm_type: str | None = None,
act_type="relu",
mode="nearest",
c2x2=False,
):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
@ -512,5 +541,6 @@ def upconv_block(
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
c2x2=c2x2,
)
return sequential(upsample, conv)

View File

@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
from .architecture.HAT import HAT
from .architecture.LaMa import LaMa
from .architecture.MAT import MAT
from .architecture.OmniSR.OmniSR import OmniSR
from .architecture.RRDB import RRDBNet as ESRGAN
from .architecture.SPSR import SPSRNet as SPSR
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
state_dict = state_dict["params"]
state_dict_keys = list(state_dict.keys())
# SRVGGNet Real-ESRGAN (v2)
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
model = RealESRGANv2(state_dict)
@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
# MAT
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
model = MAT(state_dict)
# Omni-SR
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
model = OmniSR(state_dict)
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
else:
try:

View File

@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
from .architecture.HAT import HAT
from .architecture.LaMa import LaMa
from .architecture.MAT import MAT
from .architecture.OmniSR.OmniSR import OmniSR
from .architecture.RRDB import RRDBNet as ESRGAN
from .architecture.SPSR import SPSRNet as SPSR
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
from .architecture.Swin2SR import Swin2SR
from .architecture.SwinIR import SwinIR
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
PyTorchSRModel = Union[
RealESRGANv2,
SPSR,
@ -22,6 +23,7 @@ PyTorchSRModel = Union[
SwinIR,
Swin2SR,
HAT,
OmniSR,
]

View File

@ -72,7 +72,7 @@ class MaskToImage:
FUNCTION = "mask_to_image"
def mask_to_image(self, mask):
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return (result,)
class ImageToMask:
@ -167,7 +167,7 @@ class MaskComposite:
"source": ("MASK",),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"operation": (["multiply", "add", "subtract"],),
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
}
}
@ -193,6 +193,12 @@ class MaskComposite:
output[top:bottom, left:right] = destination_portion + source_portion
elif operation == "subtract":
output[top:bottom, left:right] = destination_portion - source_portion
elif operation == "and":
output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "or":
output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
elif operation == "xor":
output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
output = torch.clamp(output, 0.0, 1.0)
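The new boolean operations round each mask to {0, 1} before combining, so fractional mask values snap to the nearest bit. A small worked example:

import torch

destination = torch.tensor([0.0, 1.0, 1.0, 0.4])
source = torch.tensor([1.0, 1.0, 0.0, 0.9])
xor = torch.bitwise_xor(destination.round().bool(), source.round().bool()).float()
# tensor([1., 0., 1., 1.]) -- 0.4 rounds down to 0, 0.9 rounds up to 1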

View File

@ -59,6 +59,12 @@ class Blend:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
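# Illustrative check: the kernel is normalized to sum to 1, so convolving
# with it preserves overall brightness for any kernel_size and sigma.
assert abs(gaussian_kernel(5, 1.0).sum().item() - 1.0) < 1e-6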
class Blur:
def __init__(self):
pass
@ -88,12 +94,6 @@ class Blur:
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
@ -101,10 +101,11 @@ class Blur:
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
@ -167,9 +168,15 @@ class Sharpen:
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.1
}),
@ -181,21 +188,21 @@ class Sharpen:
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)
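The rewritten kernel is an unsharp mask: a Gaussian scaled by -(alpha * 10), with the center coefficient re-set so the kernel sums to exactly 1. Unit DC gain means flat regions pass through unchanged. A quick check of that invariant with alpha = 1.0:

kernel = gaussian_kernel(5, 1.0) * -(1.0 * 10)
center = 5 // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
assert abs(kernel.sum().item() - 1.0) < 1e-6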

View File

@ -0,0 +1,108 @@
import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
@staticmethod
def get_batch(latents, list_ind, offset):
'''prepare a batch out of the list of latents'''
samples = latents[list_ind]['samples']
shape = samples.shape
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8:
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
if mask.shape[0] < samples.shape[0]:
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
if 'batch_index' in latents[list_ind]:
batch_inds = latents[list_ind]['batch_index']
else:
batch_inds = [x+offset for x in range(shape[0])]
return samples, mask, batch_inds
@staticmethod
def get_slices(indexable, num, batch_size):
'''divides an indexable object into num slices of length batch_size, and a remainder'''
slices = []
for i in range(num):
slices.append(indexable[i*batch_size:(i+1)*batch_size])
if num * batch_size < len(indexable):
return slices, indexable[num * batch_size:]
else:
return slices, None
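# Illustrative sketch: slicing 7 items into batches of 3 yields two full
# slices plus a remainder of one, e.g.
#   get_slices(list(range(7)), num=2, batch_size=3)
#   -> ([[0, 1, 2], [3, 4, 5]], [6])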
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
return list(zip(*result))
@staticmethod
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
# add to list if the batch has grown above the target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
current_batch = remainder
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
for s in output_list:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}

View File

@ -17,7 +17,7 @@ class UpscaleModelLoader:
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
out = model_loading.load_state_dict(sd).eval()
return (out, )

View File

@ -6,6 +6,7 @@ import threading
import heapq
import traceback
import gc
import time
import torch
import nodes
@ -26,27 +27,96 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
input_data_all[x] = [unique_id]
return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
input_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
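# Illustrative sketch of the broadcast rule, with hypothetical inputs: given
# input_data_all = {"samples": [s0, s1, s2], "vae": [v]} and a node that does
# not set INPUT_IS_LIST, func runs three times as (s0, v), (s1, v), (s2, v),
# because slice_dict repeats the last element of any shorter input list.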
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def format_value(x):
if x is None:
return None
elif isinstance(x, (int, float, bool, str)):
return x
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return
return (True, None, None)
for x in inputs:
input_data = inputs[x]
@ -55,23 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
if result[0] is not True:
# Another node failed further upstream
return result
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id }, server.client_id)
obj = class_def()
nodes.before_node_execution()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
if "ui" in outputs[unique_id]:
input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"]
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
}
return (False, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
input_data_formatted = {}
if input_data_all is not None:
input_data_formatted = {}
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!")
print(traceback.format_exc())
error_details = {
"node_id": unique_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
}
return (False, error_details, ex)
executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@ -105,7 +216,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
is_changed = class_def.IS_CHANGED(**input_data_all)
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
@ -144,10 +256,53 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
def execute(self, prompt, extra_data={}):
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
class_type = prompt[node_id]["class_type"]
# First, send back the status to the frontend depending
# on the exception type
if isinstance(ex, comfy.model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
}
self.server.send_sync("execution_interrupted", mes, self.server.client_id)
else:
if self.server.client_id is not None:
mes = {
"prompt_id": prompt_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
"current_inputs": error["current_inputs"],
"current_outputs": error["current_outputs"],
}
self.server.send_sync("execution_error", mes, self.server.client_id)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@ -155,6 +310,10 @@ class PromptExecutor:
else:
self.server.client_id = None
execution_start_time = time.perf_counter()
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
@ -169,105 +328,250 @@ class PromptExecutor:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
executed = set()
try:
to_execute = []
for x in prompt:
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
to_execute += [(0, x)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
if class_.OUTPUT_NODE == True:
valid = False
try:
m = validate_inputs(prompt, x)
valid = m[0]
except:
valid = False
if valid:
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
except Exception as e:
print(traceback.format_exc())
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
finally:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id)
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set()
output_node_id = None
to_execute = []
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item):
def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
errors = []
valid = True
for x in required_inputs:
if x not in inputs:
return (False, "Required input is missing. {}, {}".format(class_type, x))
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
if isinstance(val, list):
if len(val) != 2:
return (False, "Bad Input. {}, {}".format(class_type, x))
error = {
"type": "bad_linked_input",
"message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val
}
}
errors.append(error)
continue
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
r = validate_inputs(prompt, o_id)
if r[0] == False:
return r
received_type = r[val[1]]
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_type": received_type,
"linked_node": val
}
}
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
continue
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_inner_validation",
"message": "Exception when validating inner node",
"details": str(ex),
"extra_info": {
"input_name": x,
"input_config": info,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"linked_node": val
}
}]
validated[o_id] = (False, reasons, o_id)
continue
else:
if type_input == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
val = str(val)
inputs[x] = val
try:
if type_input == "INT":
val = int(val)
inputs[x] = val
if type_input == "FLOAT":
val = float(val)
inputs[x] = val
if type_input == "STRING":
val = str(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value",
"details": f"{x}, {val}, {ex}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
"exception_message": str(ex)
}
}
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
return (False, "Value smaller than min. {}, {}".format(class_type, x))
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x))
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "")
input_config = info
list_info = ""
# Don't send back gigantic lists, e.g. long lists of
# scanned model filepaths
if len(type_input) > 20:
list_info = f"(list of length {len(type_input)})"
input_config = None
else:
list_info = str(type_input)
error = {
"type": "value_not_in_list",
"message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}",
"extra_info": {
"input_name": x,
"input_config": input_config,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
ret = (True, [], unique_id)
validated[unique_id] = ret
return ret
def full_type_name(klass):
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt):
outputs = set()
@ -277,34 +581,86 @@ def validate_prompt(prompt):
outputs.add(x)
if len(outputs) == 0:
return (False, "Prompt has no outputs")
error = {
"type": "prompt_no_outputs",
"message": "Prompt has no outputs",
"details": "",
"extra_info": {}
}
return (False, error, [], [])
good_outputs = set()
errors = []
node_errors = {}
validated = {}
for o in outputs:
valid = False
reason = ""
reasons = []
try:
m = validate_inputs(prompt, o)
m = validate_inputs(prompt, o, validated)
valid = m[0]
reason = m[1]
except Exception as e:
print(traceback.format_exc())
reasons = m[1]
except Exception as ex:
typ, _, tb = sys.exc_info()
valid = False
reason = "Parsing error"
exception_type = full_type_name(typ)
reasons = [{
"type": "exception_during_validation",
"message": "Exception when validating node",
"details": str(ex),
"extra_info": {
"exception_type": exception_type,
"traceback": traceback.format_tb(tb)
}
}]
validated[o] = (False, reasons, o)
if valid == True:
good_outputs.add(x)
if valid is True:
good_outputs.add(o)
else:
print("Failed to validate prompt for output {} {}".format(o, reason))
print("output will be ignored")
errors += [(o, reason)]
print(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0:
print("* (prompt):")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
reasons = result[1]
# If a node upstream has errors, the nodes downstream will also
# be reported as invalid, but there will be no errors attached.
# So don't return those nodes as having errors in the response.
if valid is not True and len(reasons) > 0:
if node_id not in node_errors:
class_type = prompt[node_id]['class_type']
node_errors[node_id] = {
"errors": reasons,
"dependent_outputs": [],
"class_type": class_type
}
print(f"* {class_type} {node_id}:")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
if len(good_outputs) == 0:
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
errors_list = []
for o, output_errors in errors:
for error in output_errors:
errors_list.append(f"{error['message']}: {error['details']}")
errors_list = "\n".join(errors_list)
return (True, "")
error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
}
return (False, error, list(good_outputs), node_errors)
return (True, None, list(good_outputs), node_errors)
class PromptQueue:
@ -340,8 +696,7 @@ class PromptQueue:
prompt = self.currently_running.pop(item_id)
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
if "ui" in outputs[o]:
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
self.history[prompt[1]]["outputs"][o] = outputs[o]
self.server.queue_updated()
def get_current_queue(self):

View File

@ -1,14 +1,8 @@
import os
import time
supported_ckpt_extensions = set(['.ckpt', '.pth'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth'])
try:
import safetensors.torch
supported_ckpt_extensions.add('.safetensors')
supported_pt_extensions.add('.safetensors')
except:
print("Could not import safetensors, safetensors support disabled.")
supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
folder_names_and_paths = {}
@ -38,6 +32,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
filename_list_cache = {}
if not os.path.exists(input_directory):
os.makedirs(input_directory)
@ -118,12 +114,18 @@ def get_folder_paths(folder_name):
return folder_names_and_paths[folder_name][0][:]
def recursive_search(directory):
if not os.path.isdir(directory):
return [], {}
result = []
dirs = {directory: os.path.getmtime(directory)}
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path.join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result
for d in subdir:
path = os.path.join(root, d)
dirs[path] = os.path.getmtime(path)
return result, dirs
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
@ -132,19 +134,90 @@ def filter_files_extensions(files, extensions):
def get_full_path(folder_name, filename):
global folder_names_and_paths
if folder_name not in folder_names_and_paths:
return None
folders = folder_names_and_paths[folder_name]
filename = os.path.relpath(os.path.join("/", filename), "/")
for x in folders[0]:
full_path = os.path.join(x, filename)
if os.path.isfile(full_path):
return full_path
return None
def get_filename_list(folder_name):
def get_filename_list_(folder_name):
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
output_folders = {}
for x in folders[0]:
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
return sorted(list(output_list))
files, folders_all = recursive_search(x)
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}
return (sorted(list(output_list)), output_folders, time.perf_counter())
def cached_filename_list_(folder_name):
global filename_list_cache
global folder_names_and_paths
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
if x not in out[1]:
return None
return out
def get_filename_list(folder_name):
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
return list(out[0])
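The caching above amounts to a small, reusable pattern: a result is reused unconditionally for half a second, and afterwards only revalidated against directory mtimes. A minimal standalone sketch of that pattern (names here are illustrative, not ComfyUI's API):

    import os
    import time

    _cache = {}  # directory -> (result, dir_mtimes, created_at)

    def cached_scan(directory, scan):
        entry = _cache.get(directory)
        if entry is not None:
            result, dir_mtimes, created = entry
            if time.perf_counter() < created + 0.5:
                return result  # fresh enough, skip the mtime checks
            if all(os.path.getmtime(d) == m for d, m in dir_mtimes.items()):
                return result  # nothing changed on disk
        result, dir_mtimes = scan(directory)  # scan returns (files, {dir: mtime})
        _cache[directory] = (result, dir_mtimes, time.perf_counter())
        return result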
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input, image_width, image_height):
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix
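A sketch of how a caller consumes this helper; the prefix, dimensions, and resulting path are hypothetical:

    import os
    import folder_paths

    out = folder_paths.get_save_image_path("photos/%width%x%height%/ComfyUI",
                                           folder_paths.get_output_directory(), 512, 512)
    full_output_folder, filename, counter, subfolder, prefix = out
    # e.g. output/photos/512x512/ComfyUI_00001_.png
    path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.png")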

View File

@ -33,8 +33,8 @@ def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs)
e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
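For orientation, the queue item consumed by the worker above is now a 5-tuple; a sketch of the layout implied by the indices in the call (the names are illustrative):

    # item = (number, prompt_id, prompt, extra_data, outputs_to_execute)
    number, prompt_id, prompt, extra_data, outputs_to_execute = item
    # equivalent to e.execute(item[2], item[1], item[3], item[4]) above
    e.execute(prompt, prompt_id, extra_data, outputs_to_execute)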

nodes.py
View File

@ -6,16 +6,18 @@ import json
import hashlib
import traceback
import math
import time
from PIL import Image
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_convert
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
import comfy.sd
@ -28,6 +30,7 @@ import importlib
import folder_paths
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@ -145,9 +148,6 @@ class ConditioningSetMask:
return (c, )
class VAEDecode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -160,9 +160,6 @@ class VAEDecode:
return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -175,9 +172,6 @@ class VAEDecodeTiled:
return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -202,9 +196,6 @@ class VAEEncode:
return ({"samples":t}, )
class VAEEncodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -219,9 +210,6 @@ class VAEEncodeTiled:
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# save metadata to support latent sharing
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
file = os.path.join(full_output_folder, file)
output = {}
output["latent_tensor"] = samples["samples"]
safetensors.torch.save_file(output, file, metadata=metadata)
return {}
class LoadLatent:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"].float()}
return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, latent):
if not folder_paths.exists_annotated_filepath(latent):
return "Invalid latent file: {}".format(latent)
return True
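The .latent files written and read by these two nodes are ordinary safetensors archives with the prompt stored as string metadata; a minimal round-trip sketch (the tensor shape and file name are placeholders):

    import json
    import safetensors.torch
    import torch

    samples = {"samples": torch.zeros((1, 4, 64, 64))}
    metadata = {"prompt": json.dumps({"example": "workflow"})}
    safetensors.torch.save_file({"latent_tensor": samples["samples"]},
                                "example.latent", metadata=metadata)

    loaded = safetensors.torch.load_file("example.latent", device="cpu")
    restored = {"samples": loaded["latent_tensor"].float()}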
class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
@ -296,7 +359,10 @@ class DiffusersLoader:
paths = []
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths += next(os.walk(search_path))[1]
for root, subdir, files in os.walk(search_path, followlinks=True):
if "model_index.json" in files:
paths.append(os.path.relpath(root, start=search_path))
return {"required": {"model_path": (paths,), }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -306,12 +372,12 @@ class DiffusersLoader:
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths = next(os.walk(search_path))[1]
if model_path in paths:
model_path = os.path.join(search_path, model_path)
path = os.path.join(search_path, model_path)
if os.path.exists(path):
model_path = path
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:
@ -360,6 +426,9 @@ class LoraLoader:
CATEGORY = "loaders"
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
if strength_model == 0 and strength_clip == 0:
return (model, clip)
lora_path = folder_paths.get_full_path("loras", lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
return (model_lora, clip_lora)
@ -517,9 +586,11 @@ class ControlNetApply:
CATEGORY = "conditioning"
def apply_controlnet(self, conditioning, control_net, image, strength):
if strength == 0:
return (conditioning, )
c = []
control_hint = image.movedim(-1,1)
print(control_hint.shape)
for t in conditioning:
n = [t[0], t[1].copy()]
c_net = control_net.copy().set_cond_hint(control_hint, strength)
@ -624,6 +695,9 @@ class unCLIPConditioning:
CATEGORY = "conditioning"
def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
if strength == 0:
return (conditioning, )
c = []
for t in conditioning:
o = t[1].copy()
@ -706,22 +780,61 @@ class LatentFromBatch:
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "rotate"
FUNCTION = "frombatch"
CATEGORY = "latent"
CATEGORY = "latent/batch"
def rotate(self, samples, batch_index):
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone()
s["batch_index"] = batch_index
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,)
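The batch_index bookkeeping above keeps repeated latents from reusing the same noise; a worked example of the offset arithmetic, assuming a slice of two latents taken at index 1 and repeated twice:

    s = {"batch_index": [1, 2]}   # after LatentFromBatch(batch_index=1, length=2)
    amount = 2
    offset = max(s["batch_index"]) - min(s["batch_index"]) + 1   # 2
    s["batch_index"] = s["batch_index"] + [x + (i * offset)
                                           for i in range(1, amount)
                                           for x in s["batch_index"]]
    # -> [1, 2, 3, 4]: each repeat draws fresh, non-overlapping noise indices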
class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
crop_methods = ["disabled", "center"]
@classmethod
@ -740,6 +853,25 @@ class LatentUpscale:
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,)
class LatentUpscaleBy:
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
"scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
CATEGORY = "latent"
def upscale(self, samples, upscale_method, scale_by):
s = samples.copy()
width = round(samples["samples"].shape[3] * scale_by)
height = round(samples["samples"].shape[2] * scale_by)
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
return (s,)
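Since the node scales the latent rather than the pixels, the rounding happens at latent resolution; a worked example, assuming a 512x768 image whose latent is 64x96:

    scale_by = 1.5
    latent_w, latent_h = 64, 96            # samples["samples"].shape[3], shape[2]
    width = round(latent_w * scale_by)     # 96 latent columns -> 768 pixels
    height = round(latent_h * scale_by)    # 144 latent rows   -> 1152 pixels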
class LatentRotate:
@classmethod
def INPUT_TYPES(s):
@ -872,7 +1004,7 @@ class SetLatentNoiseMask:
def set_mask(self, samples, mask):
s = samples.copy()
s["noise_mask"] = mask
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
@ -882,8 +1014,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
skip = latent["batch_index"] if "batch_index" in latent else 0
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@ -978,39 +1110,7 @@ class SaveImage:
CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input):
input = input.replace("%width%", str(images[0].shape[1]))
input = input.replace("%height%", str(images[0].shape[0]))
return input
filename_prefix = compute_vars(filename_prefix)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(self.output_dir, subfolder)
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -1049,8 +1149,9 @@ class LoadImage:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(os.listdir(input_dir)), )},
{"image": (sorted(files), )},
}
CATEGORY = "image"
@ -1060,6 +1161,7 @@ class LoadImage:
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
@ -1090,9 +1192,10 @@ class LoadImageMask:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(os.listdir(input_dir)), ),
"channel": (s._color_channels, ),}
{"image": (sorted(files), ),
"channel": (s._color_channels, ), }
}
CATEGORY = "mask"
@ -1102,6 +1205,7 @@ class LoadImageMask:
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA")
mask = None
@ -1244,7 +1348,9 @@ NODE_CLASS_MAPPINGS = {
"VAELoader": VAELoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"LatentUpscaleBy": LatentUpscaleBy,
"LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
@ -1282,6 +1388,9 @@ NODE_CLASS_MAPPINGS = {
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1319,7 +1428,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentCrop": "Crop Latent",
"EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent",
"LatentUpscaleBy": "Upscale Latent By",
"LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
@ -1351,14 +1463,18 @@ def load_custom_node(module_path):
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
return False
def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
@ -1367,11 +1483,25 @@ def load_custom_nodes():
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
if module_path.endswith(".disabled"): continue
time_before = time.perf_counter()
success = load_custom_node(module_path)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_nodes()

View File

@ -175,6 +175,8 @@
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
@ -183,7 +185,9 @@
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
"\n",
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
" for line in p.stdout:\n",
" print(line.decode(), end='')\n",

server.py
View File

@ -7,6 +7,9 @@ import execution
import uuid
import json
import glob
from PIL import Image
from io import BytesIO
try:
import aiohttp
from aiohttp import web
@ -19,7 +22,8 @@ except ImportError:
import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
@web.middleware
async def cache_control(request: web.Request, handler):
@ -78,7 +82,7 @@ class PromptServer():
# Reusing existing session, remove old
self.sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
sid = uuid.uuid4().hex
self.sockets[sid] = ws
@ -110,49 +114,96 @@ class PromptServer():
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
@routes.post("/upload/image")
async def upload_image(request):
post = await request.post()
def get_dir_by_type(dir_type):
if dir_type is None:
dir_type = "input"
if dir_type == "input":
type_dir = folder_paths.get_input_directory()
elif dir_type == "temp":
type_dir = folder_paths.get_temp_directory()
elif dir_type == "output":
type_dir = folder_paths.get_output_directory()
return type_dir, dir_type
def image_upload(post, image_save_function=None):
image = post.get("image")
overwrite = post.get("overwrite")
if post.get("type") is None:
upload_dir = folder_paths.get_input_directory()
elif post.get("type") == "input":
upload_dir = folder_paths.get_input_directory()
elif post.get("type") == "temp":
upload_dir = folder_paths.get_temp_directory()
elif post.get("type") == "output":
upload_dir = folder_paths.get_output_directory()
if not os.path.exists(upload_dir):
os.makedirs(upload_dir)
image_upload_type = post.get("type")
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
if image and image.file:
filename = image.filename
if not filename:
return web.Response(status=400)
subfolder = post.get("subfolder", "")
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
return web.Response(status=400)
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder)
split = os.path.splitext(filename)
i = 1
while os.path.exists(os.path.join(upload_dir, filename)):
filename = f"{split[0]} ({i}){split[1]}"
i += 1
filepath = os.path.join(full_output_folder, filename)
filepath = os.path.join(upload_dir, filename)
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
pass
else:
i = 1
while os.path.exists(filepath):
filename = f"{split[0]} ({i}){split[1]}"
filepath = os.path.join(full_output_folder, filename)
i += 1
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename})
if image_save_function is not None:
image_save_function(image, post, filepath)
else:
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
else:
return web.Response(status=400)
@routes.post("/upload/image")
async def upload_image(request):
post = await request.post()
return image_upload(post)
@routes.post("/upload/mask")
async def upload_mask(request):
post = await request.post()
def image_save_function(image, post, filepath):
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
mask_pil = Image.open(image.file).convert('RGBA')
# alpha copy
new_alpha = mask_pil.getchannel('A')
original_pil.putalpha(new_alpha)
original_pil.save(filepath, compress_level=4)
return image_upload(post, image_save_function)
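A sketch of a client for this endpoint, assuming the requests library and the default server address; the file names are placeholders:

    import requests

    with open("mask.png", "rb") as mask, open("original.png", "rb") as original:
        r = requests.post("http://127.0.0.1:8188/upload/mask",
                          files={"image": ("mask.png", mask, "image/png"),
                                 "original_image": ("original.png", original, "image/png")},
                          data={"type": "input", "subfolder": "clipspace"})
    print(r.json())  # {"name": ..., "subfolder": "clipspace", "type": "input"}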
@routes.get("/view")
async def view_image(request):
if "filename" in request.rel_url.query:
type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path
if filename[0] == '/' or '..' in filename:
return web.Response(status=400)
if output_dir is None:
type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
return web.Response(status=400)
@ -162,35 +213,132 @@ class PromptServer():
return web.Response(status=403)
output_dir = full_output_dir
filename = request.rel_url.query["filename"]
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
if 'channel' not in request.rel_url.query:
channel = 'rgba'
else:
channel = request.rel_url.query["channel"]
if channel == 'rgb':
with Image.open(file) as img:
if img.mode == "RGBA":
r, g, b, a = img.split()
new_img = Image.merge('RGB', (r, g, b))
else:
new_img = img.convert("RGB")
buffer = BytesIO()
new_img.save(buffer, format='PNG')
buffer.seek(0)
return web.Response(body=buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
elif channel == 'a':
with Image.open(file) as img:
if img.mode == "RGBA":
_, _, _, a = img.split()
else:
a = Image.new('L', img.size, 255)
# alpha img
alpha_img = Image.new('RGBA', img.size)
alpha_img.putalpha(a)
alpha_buffer = BytesIO()
alpha_img.save(alpha_buffer, format='PNG')
alpha_buffer.seek(0)
return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""})
else:
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
return web.Response(status=404)
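The new channel parameter can be exercised directly; a sketch assuming the requests library, the default address, and a placeholder filename:

    import requests

    params = {"filename": "example.png", "type": "input", "channel": "a"}
    r = requests.get("http://127.0.0.1:8188/view", params=params)
    with open("alpha.png", "wb") as f:
        f.write(r.content)   # PNG whose alpha channel is the extracted mask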
@routes.get("/view_metadata/{folder_name}")
async def view_metadata(request):
folder_name = request.match_info.get("folder_name", None)
if folder_name is None:
return web.Response(status=404)
if not "filename" in request.rel_url.query:
return web.Response(status=404)
filename = request.rel_url.query["filename"]
if not filename.endswith(".safetensors"):
return web.Response(status=404)
safetensors_path = folder_paths.get_full_path(folder_name, filename)
if safetensors_path is None:
return web.Response(status=404)
out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
if out is None:
return web.Response(status=404)
dt = json.loads(out)
if not "__metadata__" in dt:
return web.Response(status=404)
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
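A sketch of querying the new endpoint, assuming the requests library and the default address:

    import requests

    stats = requests.get("http://127.0.0.1:8188/system_stats").json()
    for dev in stats["devices"]:
        print(dev["name"], dev["vram_free"], "/", dev["vram_total"])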
@routes.get("/prompt")
async def get_prompt(request):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
out[x] = node_info(x)
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/history")
@ -232,14 +380,16 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
return web.json_response({"prompt_id": prompt_id, "number": number})
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
return web.Response(body=out_string, status=resp_code)
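With the server now returning the generated prompt_id, a submission client can correlate a queued prompt with its /history entry; a sketch assuming the requests library and the default address (the empty graph is a placeholder):

    import requests
    import uuid

    payload = {"prompt": {}, "client_id": uuid.uuid4().hex}  # a real workflow graph goes in "prompt"
    r = requests.post("http://127.0.0.1:8188/prompt", json=payload)
    body = r.json()
    if r.ok:
        print(body["prompt_id"])               # server-generated uuid
    else:
        print(body["error"], body["node_errors"])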
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
@ -249,9 +399,9 @@ class PromptServer():
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
delete_func = lambda a: a[1] == id_to_delete
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
@routes.post("/interrupt")
@ -275,7 +425,7 @@ class PromptServer():
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root),
web.static('/', self.web_root, follow_symlinks=True),
])
def get_queue_info(self):

View File

@ -0,0 +1,166 @@
import { app } from "/scripts/app.js";
import { ComfyDialog, $el } from "/scripts/ui.js";
import { ComfyApp } from "/scripts/app.js";
export class ClipspaceDialog extends ComfyDialog {
static items = [];
static instance = null;
static registerButton(name, contextPredicate, callback) {
const item =
$el("button", {
type: "button",
textContent: name,
contextPredicate: contextPredicate,
onclick: callback
})
ClipspaceDialog.items.push(item);
}
static invalidatePreview() {
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
const img_preview = document.getElementById("clipspace_preview");
if(img_preview) {
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
img_preview.style.maxHeight = "100%";
img_preview.style.maxWidth = "100%";
}
}
}
static invalidate() {
if(ClipspaceDialog.instance) {
const self = ClipspaceDialog.instance;
// allow the controls to be reconstructed when copying from non-image to image content.
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
if(self.element) {
// update
self.element.removeChild(self.element.firstChild);
self.element.appendChild(children);
}
else {
// new
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
}
if(self.element.children[0].children.length <= 1) {
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
}
ClipspaceDialog.invalidatePreview();
}
}
constructor() {
super();
}
createButtons(self) {
const buttons = [];
for(let idx in ClipspaceDialog.items) {
const item = ClipspaceDialog.items[idx];
if(!item.contextPredicate || item.contextPredicate())
buttons.push(ClipspaceDialog.items[idx]);
}
buttons.push(
$el("button", {
type: "button",
textContent: "Close",
onclick: () => { this.close(); }
})
);
return buttons;
}
createImgSettings() {
if(ComfyApp.clipspace.imgs) {
const combo_items = [];
const imgs = ComfyApp.clipspace.imgs;
for(let i=0; i < imgs.length; i++) {
combo_items.push($el("option", {value:i}, [`${i}`]));
}
const combo1 = $el("select",
{id:"clipspace_img_selector", onchange:(event) => {
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
ClipspaceDialog.invalidatePreview();
} }, combo_items);
const row1 =
$el("tr", {},
[
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
$el("td", {}, [combo1])
]);
const combo2 = $el("select",
{id:"clipspace_img_paste_mode", onchange:(event) => {
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
} },
[
$el("option", {value:'selected'}, 'selected'),
$el("option", {value:'all'}, 'all')
]);
combo2.value = ComfyApp.clipspace['img_paste_mode'];
const row2 =
$el("tr", {},
[
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
$el("td", {}, [combo2])
]);
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
const row3 =
$el("tr", {}, [td]);
return $el("table", {}, [row1, row2, row3]);
}
else {
return [];
}
}
createImgPreview() {
if(ComfyApp.clipspace.imgs) {
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
}
else
return [];
}
show() {
const img_preview = document.getElementById("clipspace_preview");
ClipspaceDialog.invalidate();
this.element.style.display = "block";
}
}
app.registerExtension({
name: "Comfy.Clipspace",
init(app) {
app.openClipspace =
function () {
if(!ClipspaceDialog.instance) {
ClipspaceDialog.instance = new ClipspaceDialog(app);
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
}
if(ComfyApp.clipspace) {
ClipspaceDialog.instance.show();
}
else
app.ui.dialog.show("Clipspace is Empty!");
};
}
});

View File

@ -174,7 +174,7 @@ const els = {}
// const ctxMenu = LiteGraph.ContextMenu;
app.registerExtension({
name: id,
init() {
addCustomNodeDefs(node_defs) {
const sortObjectKeys = (unordered) => {
return Object.keys(unordered).sort().reduce((obj, key) => {
obj[key] = unordered[key];
@ -182,10 +182,10 @@ app.registerExtension({
}, {});
};
const getSlotTypes = async () => {
function getSlotTypes() {
var types = [];
const defs = await api.getNodeDefs();
const defs = node_defs;
for (const nodeId in defs) {
const nodeData = defs[nodeId];
@ -212,8 +212,8 @@ app.registerExtension({
return types;
};
const completeColorPalette = async (colorPalette) => {
var types = await getSlotTypes();
function completeColorPalette(colorPalette) {
var types = getSlotTypes();
for (const type of types) {
if (!colorPalette.colors.node_slot[type]) {

View File

@ -0,0 +1,648 @@
import { app } from "/scripts/app.js";
import { ComfyDialog, $el } from "/scripts/ui.js";
import { ComfyApp } from "/scripts/app.js";
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
// Helper function to convert a data URL to a Blob object
function dataURLToBlob(dataURL) {
const parts = dataURL.split(';base64,');
const contentType = parts[0].split(':')[1];
const byteString = atob(parts[1]);
const arrayBuffer = new ArrayBuffer(byteString.length);
const uint8Array = new Uint8Array(arrayBuffer);
for (let i = 0; i < byteString.length; i++) {
uint8Array[i] = byteString.charCodeAt(i);
}
return new Blob([arrayBuffer], { type: contentType });
}
function loadedImageToBlob(image) {
const canvas = document.createElement('canvas');
canvas.width = image.width;
canvas.height = image.height;
const ctx = canvas.getContext('2d');
ctx.drawImage(image, 0, 0);
const dataURL = canvas.toDataURL('image/png', 1);
const blob = dataURLToBlob(dataURL);
return blob;
}
async function uploadMask(filepath, formData) {
await fetch('/upload/mask', {
method: 'POST',
body: formData
}).then(response => {}).catch(error => {
console.error('Error:', error);
});
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
ClipspaceDialog.invalidatePreview();
}
function prepareRGB(image, backupCanvas, backupCtx) {
// paste mask data into alpha channel
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
if(backupData.data[i+3] == 255)
backupData.data[i+3] = 0;
else
backupData.data[i+3] = 255;
backupData.data[i] = 0;
backupData.data[i+1] = 0;
backupData.data[i+2] = 0;
}
backupCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0);
}
class MaskEditorDialog extends ComfyDialog {
static instance = null;
static getInstance() {
if(!MaskEditorDialog.instance) {
MaskEditorDialog.instance = new MaskEditorDialog(app);
}
return MaskEditorDialog.instance;
}
is_layout_created = false;
constructor() {
super();
this.element = $el("div.comfy-modal", { parent: document.body },
[ $el("div.comfy-modal-content",
[...this.createButtons()]),
]);
}
createButtons() {
return [];
}
createButton(name, callback) {
var button = document.createElement("button");
button.innerText = name;
button.addEventListener("click", callback);
return button;
}
createLeftButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "left";
button.style.marginRight = "4px";
return button;
}
createRightButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "right";
button.style.marginLeft = "4px";
return button;
}
createLeftSlider(self, name, callback) {
const divElement = document.createElement('div');
divElement.id = "maskeditor-slider";
divElement.style.cssFloat = "left";
divElement.style.fontFamily = "sans-serif";
divElement.style.marginRight = "4px";
divElement.style.color = "var(--input-text)";
divElement.style.backgroundColor = "var(--comfy-input-bg)";
divElement.style.borderRadius = "8px";
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "21px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
divElement.style.top = "2px";
self.brush_slider_input = document.createElement('input');
self.brush_slider_input.setAttribute('type', 'range');
self.brush_slider_input.setAttribute('min', '1');
self.brush_slider_input.setAttribute('max', '100');
self.brush_slider_input.setAttribute('value', '10');
const labelElement = document.createElement("label");
labelElement.textContent = name;
divElement.appendChild(labelElement);
divElement.appendChild(self.brush_slider_input);
self.brush_slider_input.addEventListener("change", callback);
return divElement;
}
setlayout(imgCanvas, maskCanvas) {
const self = this;
// If this element uses relative positioning, it is best used only as a hidden placeholder for padding;
// otherwise it can exceed a certain size and overflow the window.
var placeholder = document.createElement("div");
placeholder.style.position = "relative";
placeholder.style.height = "50px";
var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px";
bottom_panel.style.left = "20px";
bottom_panel.style.right = "20px";
bottom_panel.style.height = "50px";
var brush = document.createElement("div");
brush.id = "brush";
brush.style.backgroundColor = "transparent";
brush.style.outline = "1px dashed black";
brush.style.boxShadow = "0 0 0 1px white";
brush.style.borderRadius = "50%";
brush.style.MozBorderRadius = "50%";
brush.style.WebkitBorderRadius = "50%";
brush.style.position = "absolute";
brush.style.zIndex = 8889;
brush.style.pointerEvents = "none";
this.brush = brush;
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must have a lower z-index than bottom_panel so it does not cover the buttons
this.element.appendChild(bottom_panel);
document.body.appendChild(brush);
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null);
});
var clearButton = this.createLeftButton("Clear",
() => {
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
});
var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.close();
});
this.saveButton = this.createRightButton("Save", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.save();
});
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must have a lower z-index than bottom_panel so it does not cover the buttons
this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(brush_size_slider);
imgCanvas.style.position = "relative";
imgCanvas.style.top = "200px";
imgCanvas.style.left = "0";
maskCanvas.style.position = "absolute";
}
show() {
if(!this.is_layout_created) {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas);
// prepare content
this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas;
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
this.setEventHandler(maskCanvas);
this.is_layout_created = true;
// replacement for an onClose hook, since close() does not actually close the dialog
const self = this;
const observer = new MutationObserver(function(mutations) {
mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
ComfyApp.onClipspaceEditorClosed();
}
self.last_display_style = self.element.style.display;
}
});
});
const config = { attributes: true };
observer.observe(this.element, config);
}
this.setImages(this.imgCanvas, this.backupCanvas);
if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node";
}
else {
this.saveButton.innerText = "Save";
}
this.saveButton.disabled = false;
this.element.style.display = "block";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
}
isOpened() {
return this.element.style.display == "block";
}
setImages(imgCanvas, backupCanvas) {
const imgCtx = imgCanvas.getContext('2d');
const backupCtx = backupCanvas.getContext('2d');
const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load
const orig_image = new Image();
window.addEventListener("resize", () => {
// repositioning
imgCanvas.width = window.innerWidth - 250;
imgCanvas.height = window.innerHeight - 200;
// redraw image
let drawWidth = orig_image.width;
let drawHeight = orig_image.height;
if (orig_image.width > imgCanvas.width) {
drawWidth = imgCanvas.width;
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
}
if (drawHeight > imgCanvas.height) {
drawHeight = imgCanvas.height;
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
}
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});
const filepath = ComfyApp.clipspace.images;
const touched_image = new Image();
touched_image.onload = function() {
backupCanvas.width = touched_image.width;
backupCanvas.height = touched_image.height;
prepareRGB(touched_image, backupCanvas, backupCtx);
};
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url;
// original image load
orig_image.onload = function() {
window.dispatchEvent(new Event('resize'));
};
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;
}
setEventHandler(maskCanvas) {
maskCanvas.addEventListener("contextmenu", (event) => {
event.preventDefault();
});
const self = this;
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
}
brush_size = 10;
drawing_mode = false;
lastx = -1;
lasty = -1;
lasttime = 0;
static handleKeyDown(event) {
const self = MaskEditorDialog.instance;
if (event.key === ']') {
self.brush_size = Math.min(self.brush_size+2, 100);
} else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1);
} else if(event.key === 'Enter') {
self.save();
}
self.updateBrushPreview(self);
}
static handlePointerUp(event) {
event.preventDefault();
MaskEditorDialog.instance.drawing_mode = false;
}
updateBrushPreview(self) {
const brush = self.brush;
var centerX = self.cursorX;
var centerY = self.cursorY;
brush.style.width = self.brush_size * 2 + "px";
brush.style.height = self.brush_size * 2 + "px";
brush.style.left = (centerX - self.brush_size) + "px";
brush.style.top = (centerY - self.brush_size) + "px";
}
handleWheelEvent(self, event) {
if(event.deltaY < 0)
self.brush_size = Math.min(self.brush_size+2, 100);
else
self.brush_size = Math.max(self.brush_size-2, 1);
self.brush_slider_input.value = self.brush_size;
self.updateBrushPreview(self);
}
draw_move(self, event) {
event.preventDefault();
this.cursorX = event.pageX;
this.cursorY = event.pageY;
self.updateBrushPreview(self);
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
var diff = performance.now() - self.lasttime;
const maskRect = self.maskCanvas.getBoundingClientRect();
var x = event.offsetX;
var y = event.offsetY
if(event.offsetX == null) {
x = event.targetTouches[0].clientX - maskRect.left;
}
if(event.offsetY == null) {
y = event.targetTouches[0].clientY - maskRect.top;
}
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
// The firing interval of PointerEvents for pens is unreliable, so they are supplemented with TouchEvents.
brush_size *= this.last_pressure;
}
else {
brush_size = this.brush_size;
}
if(diff > 20 && !this.drawing_mode)
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
var dx = x - self.lastx;
var dy = y - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
var directionX = dx / distance;
var directionY = dy / distance;
for (var i = 0; i < distance; i+=5) {
var px = self.lastx + (directionX * i);
var py = self.lasty + (directionY * i);
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
}
self.lastx = x;
self.lasty = y;
});
self.lasttime = performance.now();
}
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
brush_size *= this.last_pressure;
}
else {
brush_size = this.brush_size;
}
if(diff > 20 && !this.drawing_mode) // drawing_mode cannot be tracked reliably for touch events
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
var dx = x - self.lastx;
var dy = y - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
var directionX = dx / distance;
var directionY = dy / distance;
for (var i = 0; i < distance; i+=5) {
var px = self.lastx + (directionX * i);
var py = self.lasty + (directionY * i);
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
}
self.lastx = x;
self.lasty = y;
});
self.lasttime = performance.now();
}
}
handlePointerDown(self, event) {
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
this.last_pressure = event.pressure;
}
if ([0, 2, 5].includes(event.button)) {
self.drawing_mode = true;
event.preventDefault();
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
self.maskCtx.beginPath();
if (event.button == 0) {
self.maskCtx.fillStyle = "rgb(0,0,0)";
self.maskCtx.globalCompositeOperation = "source-over";
} else {
self.maskCtx.globalCompositeOperation = "destination-out";
}
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.lastx = x;
self.lasty = y;
self.lasttime = performance.now();
}
}
async save() {
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
backupCtx.drawImage(this.maskCanvas,
0, 0, this.maskCanvas.width, this.maskCanvas.height,
0, 0, this.backupCanvas.width, this.backupCanvas.height);
// paste mask data into alpha channel
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
if(backupData.data[i+3] == 255)
backupData.data[i+3] = 0;
else
backupData.data[i+3] = 255;
backupData.data[i] = 0;
backupData.data[i+1] = 0;
backupData.data[i+2] = 0;
}
backupCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0);
const formData = new FormData();
const filename = "clipspace-mask-" + performance.now() + ".png";
const item =
{
"filename": filename,
"subfolder": "clipspace",
"type": "input",
};
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[0] = item;
if(ComfyApp.clipspace.widgets) {
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0)
ComfyApp.clipspace.widgets[index].value = item;
}
const dataURL = this.backupCanvas.toDataURL();
const blob = dataURLToBlob(dataURL);
const original_blob = loadedImageToBlob(this.image);
formData.append('image', blob, filename);
formData.append('original_image', original_blob);
formData.append('type', "input");
formData.append('subfolder', "clipspace");
this.saveButton.innerText = "Saving...";
this.saveButton.disabled = true;
await uploadMask(item, formData);
ComfyApp.onClipspaceEditorSave();
this.close();
}
}
app.registerExtension({
name: "Comfy.MaskEditor",
init(app) {
ComfyApp.open_maskeditor =
function () {
const dlg = MaskEditorDialog.getInstance();
if(!dlg.isOpened()) {
dlg.show();
}
};
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
}
});

View File

@ -300,7 +300,7 @@ app.registerExtension({
}
}
if (widget.type === "number") {
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
}

View File

@ -14,5 +14,5 @@
window.graph = app.graph;
</script>
</head>
<body></body>
<body class="litegraph"></body>
</html>

View File

@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
//when clicked on top of a node
//and it is not interactive
if (node && this.allow_interaction && !skip_action && !this.read_only) {
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
if (!this.live_mode && !node.flags.pinned) {
this.bringToFront(node);
} //if it wasn't selected?
//not dragging mouse to connect two slots
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
//Search for corner for resize
if ( !skip_action &&
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
}
//double clicking
if (is_double_click && this.selected_nodes[node.id]) {
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
//double click node
if (node.onDblClick) {
node.onDblClick( e, pos, this );
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
if (this.dragging_rectangle)
{
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
this.ds.offset[1] += delta[1] / this.ds.scale;
this.dirty_canvas = true;
this.dirty_bgcanvas = true;
} else if (this.allow_interaction && !this.read_only) {
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
if (this.connecting_node) {
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
//remove mouseover flag
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
if (show_text) {
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
}
break;
case "toggle":
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
if (show_text) {
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = w.value ? text_color : secondary_text_color;
ctx.textAlign = "right";
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(
w.name + " " + Number(w.value).toFixed(3),
(w.label || w.name) + " " + Number(w.value).toFixed(3),
widget_width * 0.5,
y + H * 0.7
);
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
}
ctx.fillStyle = secondary_text_color;
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillStyle = text_color;
ctx.textAlign = "right";
if (w.type == "number") {
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
//ctx.stroke();
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = text_color;
ctx.textAlign = "right";
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
event,
active_widget
) {
if (!node.widgets || !node.widgets.length) {
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
return null;
}
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
canvas.graph.add(group);
};
/**
* Determines the furthest nodes in each direction
* @param nodes {LGraphNode[]} the nodes from which boundary nodes will be extracted
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.getBoundaryNodes = function(nodes) {
let top = null;
let right = null;
let bottom = null;
let left = null;
for (const nID in nodes) {
const node = nodes[nID];
const [x, y] = node.pos;
const [width, height] = node.size;
if (top === null || y < top.pos[1]) {
top = node;
}
if (right === null || x + width > right.pos[0] + right.size[0]) {
right = node;
}
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
bottom = node;
}
if (left === null || x < left.pos[0]) {
left = node;
}
}
return {
"top": top,
"right": right,
"bottom": bottom,
"left": left
};
}
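A minimal usage sketch of getBoundaryNodes; the two object literals below are hypothetical stand-ins that only carry the pos/size fields the function reads:

    // a is the top-left-most node; b sits further right and lower.
    const a = { pos: [0, 0], size: [100, 50] };
    const b = { pos: [200, 300], size: [80, 40] };
    const bounds = LGraphCanvas.getBoundaryNodes([a, b]);
    // bounds.left === a and bounds.top === a; bounds.right === b and bounds.bottom === b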
/**
* Determines the furthest nodes in each direction for the currently selected nodes
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
}
/**
* Aligns the given nodes along the specified edge
* @param {LGraphNode[]} nodes a list of nodes
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
* @param {LGraphNode=} align_to Node to align to (if omitted, align to the furthest node in the given direction)
*/
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
if (!nodes) {
return;
}
const canvas = LGraphCanvas.active_canvas;
let boundaryNodes = null;
if (align_to === undefined) {
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes);
} else {
boundaryNodes = {
"top": align_to,
"right": align_to,
"bottom": align_to,
"left": align_to
}
}
for (const [_, node] of Object.entries(nodes)) {
switch (direction) {
case "right":
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
break;
case "left":
node.pos[0] = boundaryNodes["left"].pos[0];
break;
case "top":
node.pos[1] = boundaryNodes["top"].pos[1];
break;
case "bottom":
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
break;
}
}
canvas.dirty_canvas = true;
canvas.dirty_bgcanvas = true;
};
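As a usage sketch (assuming some nodes are already selected on the active canvas), left-aligning the current selection would look like this; someNode stands for any node reference you already hold:

    // Align every selected node to the left edge of the left-most one.
    LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, "left");
    // Or align the selection to one particular node instead of the boundary:
    // LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, "top", someNode);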
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
}
}
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
}
}
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
var canvas = LGraphCanvas.active_canvas;
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
}*/
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align",
has_submenu: true,
callback: LGraphCanvas.onGroupAlign,
})
}
if (this._graph_stack && this._graph_stack.length > 0) {
options.push(null, {
content: "Close subgraph",
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
callback: LGraphCanvas.onMenuNodeToSubgraph
});
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align Selected To",
has_submenu: true,
callback: LGraphCanvas.onNodeAlign,
})
}
options.push(null, {
content: "Remove",
disabled: !(node.removable !== false && !node.block_delete ),

View File

@ -88,6 +88,12 @@ class ComfyApi extends EventTarget {
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
break;
case "execution_start":
this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
break;
case "execution_error":
this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
break;
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
@ -163,7 +169,7 @@ class ComfyApi extends EventTarget {
if (res.status !== 200) {
throw {
response: await res.text(),
response: await res.json(),
};
}
}
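Because the thrown object now carries the parsed JSON body rather than raw text, callers can read structured fields from it. A sketch of consuming that shape (assuming prompt is the { workflow, output } pair built by the app):

    try {
        await api.queuePrompt(0, prompt);
    } catch (err) {
        // err.response is the decoded JSON body, e.g. { error: {...}, node_errors: {...} }
        console.error(err.response?.error?.message);
    }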

View File

@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@ -25,6 +25,9 @@ export class ComfyApp {
* @type {serialized node object}
*/
static clipspace = null;
static clipspace_invalidate_handler = null;
static open_maskeditor = null;
static clipspace_return_node = null;
constructor() {
this.ui = new ComfyUI(this);
@ -48,6 +51,114 @@ export class ComfyApp {
this.shiftDown = false;
}
static isImageNode(node) {
return node && (node.imgs || (node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0));
}
static onClipspaceEditorSave() {
if(ComfyApp.clipspace_return_node) {
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
}
}
static onClipspaceEditorClosed() {
ComfyApp.clipspace_return_node = null;
}
static copyToClipspace(node) {
var widgets = null;
if(node.widgets) {
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
var imgs = undefined;
var orig_imgs = undefined;
if(node.imgs != undefined) {
imgs = [];
orig_imgs = [];
for (let i = 0; i < node.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = node.imgs[i].src;
orig_imgs[i] = imgs[i];
}
}
var selectedIndex = 0;
if(node.imageIndex) {
selectedIndex = node.imageIndex;
}
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': orig_imgs,
'images': node.images,
'selectedIndex': selectedIndex,
'img_paste_mode': 'selected' // reset to default img_paste_mode state on copy action
};
ComfyApp.clipspace_return_node = null;
if(ComfyApp.clipspace_invalidate_handler) {
ComfyApp.clipspace_invalidate_handler();
}
}
static pasteFromClipspace(node) {
if(ComfyApp.clipspace) {
// image paste
if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else {
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
}
}
if(ComfyApp.clipspace.imgs) {
// deep-copy to cut link with clipspace
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
const img = new Image();
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
node.imgs = [img];
node.imageIndex = 0;
}
else {
const imgs = [];
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
}
node.imgs = imgs;
}
}
}
if(node.widgets) {
if(ComfyApp.clipspace.images) {
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
const index = node.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0) {
node.widgets[index].value = clip_image;
}
}
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
});
}
}
app.graph.setDirtyCanvas(true);
}
}
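A sketch of the resulting copy/paste round trip (sourceNode and targetNode are assumed to be image-bearing nodes already in the graph):

    ComfyApp.copyToClipspace(sourceNode);
    ComfyApp.clipspace['img_paste_mode'] = 'selected'; // paste only the selected image
    ComfyApp.pasteFromClipspace(targetNode);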
/**
* Invoke an extension callback
* @param {keyof ComfyExtension} method The extension callback to execute
@ -137,81 +248,30 @@ export class ComfyApp {
}
}
options.push(
{
content: "Copy (Clipspace)",
callback: (obj) => {
var widgets = null;
if(this.widgets) {
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
let img = new Image();
var imgs = undefined;
if(this.imgs != undefined) {
img.src = this.imgs[0].src;
imgs = [img];
}
// prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) {
options.push({
content: "Copy (Clipspace)",
callback: (obj) => { ComfyApp.copyToClipspace(this); }
});
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': imgs,
'images': this.images
};
}
});
if(ComfyApp.clipspace != null) {
options.push({
content: "Paste (Clipspace)",
callback: () => { ComfyApp.pasteFromClipspace(this); }
});
}
if(ComfyApp.clipspace != null) {
options.push(
{
content: "Paste (Clipspace)",
callback: () => {
if(ComfyApp.clipspace != null) {
if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
if (prop) {
prop.callback(value);
}
});
}
// image paste
if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
var filename = "";
if(this.images && ComfyApp.clipspace.images) {
this.images = ComfyApp.clipspace.images;
}
if(ComfyApp.clipspace.images != undefined) {
const clip_image = ComfyApp.clipspace.images[0];
if(clip_image.subfolder != '')
filename = `${clip_image.subfolder}/`;
filename += `${clip_image.filename} [${clip_image.type}]`;
}
else if(ComfyApp.clipspace.widgets != undefined) {
const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
if(index_in_clip >= 0) {
filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
}
}
const index = this.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
this.imgs = ComfyApp.clipspace.imgs;
this.widgets[index].value = filename;
if(this.widgets_values != undefined) {
this.widgets_values[index] = filename;
}
}
}
this.trigger('changed');
if(ComfyApp.isImageNode(this)) {
options.push({
content: "Open in MaskEditor",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor();
}
}
}
);
});
}
}
};
}
@ -711,16 +771,27 @@ export class ComfyApp {
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = origDrawNodeShape.apply(this, arguments);
const nodeErrors = self.lastPromptError?.node_errors[node.id];
let color = null;
let lineWidth = 1;
if (node.id === +self.runningNodeId) {
color = "#0f0";
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
color = "dodgerblue";
}
else if (self.lastPromptError != null && nodeErrors?.errors) {
color = "red";
lineWidth = 2;
}
else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) {
color = "#f0f";
lineWidth = 2;
}
if (color) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.lineWidth = lineWidth;
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
@ -747,11 +818,28 @@ export class ComfyApp {
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
}
if (self.progress) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
if (self.progress && node.id === +self.runningNodeId) {
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillStyle = bgcolor;
}
// Highlight inputs that failed validation
if (nodeErrors) {
ctx.lineWidth = 2;
ctx.strokeStyle = "red";
for (const error of nodeErrors.errors) {
if (error.extra_info && error.extra_info.input_name) {
const inputIndex = node.findInputSlot(error.extra_info.input_name)
if (inputIndex !== -1) {
let pos = node.getConnectionPos(true, inputIndex);
ctx.beginPath();
ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false)
ctx.stroke();
}
}
}
}
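The loop above expects each entry of lastPromptError.node_errors to carry per-input details; an illustrative shape, inferred only from the fields read here and in #formatPromptError below (values are made up):

    // node_errors[node.id] = {
    //     class_type: "KSampler",
    //     errors: [{
    //         message: "Value not in list",
    //         details: "sampler_name: 'foo'",
    //         extra_info: { input_name: "sampler_name" } // drives the red circle on that input
    //     }]
    // }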
@ -809,6 +897,17 @@ export class ComfyApp {
}
});
api.addEventListener("execution_start", ({ detail }) => {
this.lastExecutionError = null
});
api.addEventListener("execution_error", ({ detail }) => {
this.lastExecutionError = detail;
const formattedError = this.#formatExecutionError(detail);
this.ui.dialog.show(formattedError);
this.canvas.draw(true, true);
});
api.init();
}
@ -842,7 +941,9 @@ export class ComfyApp {
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
const mainCanvas = document.createElement("canvas")
mainCanvas.style.touchAction = "none"
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
@ -909,6 +1010,11 @@ export class ComfyApp {
const app = this;
// Load node definitions from the backend
const defs = await api.getNodeDefs();
await this.registerNodesFromDefs(defs);
await this.#invokeExtensionsAsync("registerCustomNodes");
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
@ -954,7 +1060,8 @@ export class ComfyApp {
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
const outputName = nodeData["output_name"][o] || output;
this.addOutput(outputName, output);
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();
@ -980,8 +1087,6 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
}
await this.#invokeExtensionsAsync("registerCustomNodes");
}
/**
@ -1180,6 +1285,43 @@ export class ComfyApp {
return { workflow, output };
}
#formatPromptError(error) {
if (error == null) {
return "(unknown error)"
}
else if (typeof error === "string") {
return error;
}
else if (error.stack && error.message) {
return error.toString()
}
else if (error.response) {
let message = error.response.error.message;
if (error.response.error.details)
message += ": " + error.response.error.details;
for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) {
message += "\n" + nodeError.class_type + ":"
for (const errorReason of nodeError.errors) {
message += "\n - " + errorReason.message + ": " + errorReason.details
}
}
return message
}
return "(unknown error)"
}
#formatExecutionError(error) {
if (error == null) {
return "(unknown error)"
}
const traceback = error.traceback.join("")
const nodeId = error.node_id
const nodeType = error.node_type
return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}`
}
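For reference, an illustrative execution_error payload carrying just the fields read above (values are made up):

    // {
    //     "node_id": "3",
    //     "node_type": "KSampler",
    //     "exception_message": "tensor size mismatch",
    //     "traceback": ["Traceback (most recent call last):\n", "..."]
    // }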
async queuePrompt(number, batchCount = 1) {
this.#queueItems.push({ number, batchCount });
@ -1187,8 +1329,10 @@ export class ComfyApp {
if (this.#processingQueue) {
return;
}
this.#processingQueue = true;
this.lastPromptError = null;
try {
while (this.#queueItems.length) {
({ number, batchCount } = this.#queueItems.pop());
@ -1199,7 +1343,12 @@ export class ComfyApp {
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError);
if (error.response) {
this.lastPromptError = error.response;
this.canvas.draw(true, true);
}
break;
}
@ -1245,6 +1394,11 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
}
}
}
@ -1273,14 +1427,19 @@ export class ComfyApp {
const def = defs[node.type];
// HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes,
// and additional work is needed to consider the primitive logic in the refresh logic.
if(!def)
continue;
for(const widgetNum in node.widgets) {
const widget = node.widgets[widgetNum]
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
widget.options.values = def["input"]["required"][widget.name][0];
if(!widget.options.values.includes(widget.value)) {
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
widget.value = widget.options.values[0];
widget.callback(widget.value);
}
}
}
@ -1292,6 +1451,8 @@ export class ComfyApp {
*/
clean() {
this.nodeOutputs = {};
this.lastPromptError = null;
this.lastExecutionError = null;
}
}

View File

@ -47,12 +47,29 @@ export function getPngMetadata(file) {
});
}
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const safetensorsData = new Uint8Array(event.target.result);
const dataView = new DataView(safetensorsData.buffer);
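// safetensors layout: the first 8 bytes are a little-endian header length,
// followed by that many bytes of JSON whose __metadata__ section holds the
// embedded workflow. Only the low 32 bits of the length are read here,
// which is more than enough for any realistic header.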
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};
reader.readAsArrayBuffer(file);
});
}
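A minimal usage sketch, mirroring the .latent drop handler in app.js (file is a File object picked by the user):

    getLatentMetadata(file).then((meta) => {
        if (meta && meta.workflow) {
            app.loadGraphData(JSON.parse(meta.workflow));
        }
    });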
export async function importA1111(graph, parameters) {
const p = parameters.lastIndexOf("\nSteps:");
if (p > -1) {
const embeddings = await api.getEmbeddings();
const opts = parameters
.substr(p)
.split("\n")[1]
.split(",")
.reduce((p, n) => {
const s = n.split(":");

View File

@ -465,7 +465,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png",
accept: ".json,image/png,.latent",
style: { display: "none" },
parent: document.body,
onchange: () => {
@ -581,6 +581,7 @@ export class ComfyUI {
}),
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
if (!confirmClear.value || confirm("Clear workflow?")) {
app.clean();

View File

@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
var v = valueControl.value;
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
switch (v) {
case "increment":
current_index += 1;
break;
case "decrement":
current_index -= 1;
break;
case "randomize":
current_index = Math.floor(Math.random() * current_length);
default:
break;
}
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
return valueControl;
};
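To illustrate the new combo branch, here is a hypothetical widget (not from the codebase) stepping one entry forward with the same clamping:

    const targetWidget = { type: "combo", value: "euler", options: { values: ["euler", "heun", "ddim"] }, callback(v) {} };
    // "increment": advance the index by one, clamped to the ends of the list.
    let idx = targetWidget.options.values.indexOf(targetWidget.value) + 1;
    idx = Math.min(targetWidget.options.values.length - 1, Math.max(0, idx));
    targetWidget.value = targetWidget.options.values[idx]; // "heun"
    targetWidget.callback(targetWidget.value);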
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const t = ctx.getTransform();
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
Object.assign(this.inputEl.style, {
left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
background: (!node.color)?'':node.color,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
transformOrigin: "0 0",
transform: transform,
left: "0px",
top: "0px",
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
fontSize: `${t.d * 10.0}px`,
});
this.inputEl.hidden = !visible;
},
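The textarea is now positioned with a CSS matrix instead of computed left/top values, so it tracks canvas pan and zoom automatically. A sketch of the same composition with made-up numbers:

    // Backing store twice the element size (0.5 scale), then the canvas'
    // own pan/zoom matrix, then the widget's margin offset.
    const m = new DOMMatrix()
        .scaleSelf(0.5, 0.5)
        .multiplySelf(new DOMMatrix([2, 0, 0, 2, 100, 40])) // zoom 2x, panned to (100, 40)
        .translateSelf(10, 10 + 60);
    // m maps the textarea's local origin into element coordinates.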
@ -266,10 +297,46 @@ export const ComfyWidgets = {
node.imgs = [img];
app.graph.setDirtyCanvas(true);
};
img.src = `/view?filename=${name}&type=input`;
let folder_separator = name.lastIndexOf("/");
let subfolder = "";
if (folder_separator > -1) {
subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1);
}
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
node.setSizeForImage?.();
}
var default_value = imageWidget.value;
Object.defineProperty(imageWidget, "value", {
set : function(value) {
this._real_value = value;
},
get : function() {
let value = "";
if (this._real_value) {
value = this._real_value;
} else {
return default_value;
}
if (value.filename) {
let real_value = value;
value = "";
if (real_value.subfolder) {
value = real_value.subfolder + "/";
}
value += real_value.filename;
if(real_value.type && real_value.type !== "input")
value += ` [${real_value.type}]`;
}
return value;
}
});
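With this getter, assigning an image descriptor yields a flattened string; an illustrative round trip using the fields handled above:

    imageWidget.value = { filename: "mask.png", subfolder: "clipspace", type: "output" };
    // reading it back now gives "clipspace/mask.png [output]"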
// Add our own callback to the combo widget to render an image when it changes
const cb = node.callback;
imageWidget.callback = function () {

View File

@ -39,6 +39,8 @@ body {
padding: 2px;
resize: none;
border: none;
box-sizing: border-box;
font-size: 10px;
}
.comfy-modal {
@ -287,6 +289,11 @@ button.comfy-queue-btn {
/* Context menu */
.litegraph .dialog {
z-index: 1;
font-family: Arial;
}
.litegraph .litemenu-entry.has_submenu {
position: relative;
padding-right: 20px;
@ -329,6 +336,7 @@ button.comfy-queue-btn {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
display: block;
}
.litegraph.litesearchbox input,