diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml index 15322c86a..2d6048a23 100644 --- a/.github/workflows/windows_release_cu118_package.yml +++ b/.github/workflows/windows_release_cu118_package.yml @@ -30,6 +30,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - shell: bash run: | cd .. diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b6a18ec0a..767a7216b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - uses: actions/setup-python@v4 with: python-version: '3.11.3' diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..43877fb83 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,111 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..26930428f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + h = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5e4d2b60f..91e7d60ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -146,6 +146,41 @@ class ResnetBlock(nn.Module): return x+h +def slice_attention(q, k, v): + r1 = torch.zeros_like(k, device=q.device) + scale = (int(q.shape[-1])**(-0.5)) + + mem_free_total = model_management.get_free_memory(q.device) + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = torch.bmm(q[:, i:end], k) * scale + + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) + del s1 + + r1[:, :, i:end] = torch.bmm(v, s2) + del s2 + break + except model_management.OOM_EXCEPTION as e: + steps *= 2 + if steps > 128: + raise e + print("out of memory error, increasing steps and trying again", steps) + + return r1 class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -183,48 +218,15 @@ class AttnBlock(nn.Module): # compute attention b,c,h,w = q.shape - scale = (int(c)**(-0.5)) q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw v = v.reshape(b,c,h*w) - r1 = torch.zeros_like(k, device=q.device) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = torch.bmm(q[:, i:end], k) * scale - - s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) - del s1 - - r1[:, :, i:end] = torch.bmm(v, s2) - del s2 - break - except model_management.OOM_EXCEPTION as e: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - + r1 = slice_attention(q, k, v) h_ = r1.reshape(b,c,h,w) del r1 - h_ = self.proj_out(h_) return x+h_ @@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..a492ca6b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,23 +1,29 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): - CPU = 0 + DISABLED = 0 NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - MPS = 5 + SHARED = 5 + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 # Determine VRAM State vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU total_vram = 0 -total_vram_available_mb = -1 -accelerate_enabled = False +lowvram_available = True xpu_available = False directml_enabled = False @@ -31,30 +37,80 @@ if args.directml is not None: directml_device = torch_directml.device(device_index) print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import torch - if directml_enabled: - total_vram = 4097 #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS +except: + pass + +if args.cpu: + cpu_state = CPUState.CPU + +def get_torch_device(): + global xpu_available + global directml_enabled + global cpu_state + if directml_enabled: + global directml_device + return directml_device + if cpu_state == CPUState.MPS: + return torch.device("mps") + if cpu_state == CPUState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -92,6 +148,7 @@ if ENABLE_PYTORCH_ATTENTION: if args.lowvram: set_vram_to = VRAMState.LOW_VRAM + lowvram_available = True elif args.novram: set_vram_to = VRAMState.NO_VRAM elif args.highvram: @@ -102,32 +159,42 @@ if args.force_fp32: print("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True - -if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): +if lowvram_available: try: import accelerate - accelerate_enabled = True - vram_state = set_vram_to + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to except Exception as e: import traceback print(traceback.format_exc()) - print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + print("ERROR: LOW VRAM MODE NEEDS accelerate.") + lowvram_available = False - total_vram_available_mb = (total_vram - 1024) // 2 - total_vram_available_mb = int(max(256, total_vram_available_mb)) -try: - if torch.backends.mps.is_available(): - vram_state = VRAMState.MPS -except: - pass +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED -if args.cpu: - vram_state = VRAMState.CPU +if cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED print(f"Set vram state to: {vram_state.name}") +def get_torch_device_name(device): + if hasattr(device, 'type'): + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + + current_loaded_model = None current_gpu_controlnets = [] @@ -173,22 +240,29 @@ def load_model_gpu(model): model.unpatch_model() raise e - model.model_patches_to(get_torch_device()) + torch_dev = get_torch_device() + model.model_patches_to(torch_dev) + + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = model.model_size() + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + current_loaded_model = model - if vram_state == VRAMState.CPU: + + if vram_set_state == VRAMState.DISABLED: pass - elif vram_state == VRAMState.MPS: - mps_device = torch.device("mps") - real_model.to(mps_device) - pass - elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(get_torch_device()) else: - if vram_state == VRAMState.NO_VRAM: + if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True @@ -197,7 +271,7 @@ def load_model_gpu(model): def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state - if vram_state == VRAMState.CPU: + if vram_state == VRAMState.DISABLED: return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: @@ -233,22 +307,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type @@ -258,7 +316,8 @@ def get_autocast_device(dev): def xformers_enabled(): global xpu_available global directml_enabled - if vram_state == VRAMState.CPU: + global cpu_state + if cpu_state != CPUState.GPU: return False if xpu_available: return False @@ -330,12 +389,12 @@ def maximum_batch_area(): return int(max(area, 0)) def cpu_mode(): - global vram_state - return vram_state == VRAMState.CPU + global cpu_state + return cpu_state == CPUState.CPU def mps_mode(): - global vram_state - return vram_state == VRAMState.MPS + global cpu_state + return cpu_state == CPUState.MPS def should_use_fp16(): global xpu_available @@ -367,7 +426,10 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - if xpu_available: + global cpu_state + if cpu_state == CPUState.MPS: + torch.mps.empty_cache() + elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy/samplers.py b/comfy/samplers.py index dcf93cca2..1fb928f8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: @@ -362,19 +386,8 @@ def resolve_cond_masks(conditions, h, w, device): else: box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) - # Make sure the height and width are divisible by 8 - if X % 8 != 0: - newx = X // 8 * 8 - W = W + (X - newx) - X = newx - if Y % 8 != 0: - newy = Y // 8 * 8 - H = H + (Y - newy) - Y = newy - if H % 8 != 0: - H = H + (8 - (H % 8)) - if W % 8 != 0: - W = W + (8 - (W % 8)) + H = max(8, H) + W = max(8, W) area = (int(H), int(W), int(Y), int(X)) modified['area'] = area @@ -482,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -519,6 +532,8 @@ class KSampler: if self.scheduler == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": diff --git a/comfy/sd.py b/comfy/sd.py index 7ffaaf3ec..2f92b0d34 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . import diffusers_convert def load_torch_file(ckpt): if ckpt.lower().endswith(".safetensors"): @@ -324,15 +325,29 @@ def model_lora_keys(model, key_map={}): return key_map + class ModelPatcher: - def __init__(self, model): + def __init__(self, model, size=0): + self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} + self.model_size() + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + return size def clone(self): - n = ModelPatcher(self.model) + n = ModelPatcher(self.model, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) return n @@ -553,10 +568,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() @@ -630,12 +651,9 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor @@ -652,7 +670,7 @@ def resize_image_to(tensor, target_latent_tensor, batched_number): return torch.cat([tensor] * batched_number, dim=0) class ControlNet: - def __init__(self, control_model, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None @@ -661,6 +679,7 @@ class ControlNet: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None @@ -672,7 +691,9 @@ class ControlNet: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -694,6 +715,9 @@ class ControlNet: key = 'output' index = i x = control[i] + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) @@ -724,7 +748,7 @@ class ControlNet: self.cond_hint = None def copy(self): - c = ControlNet(self.control_model) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c.cond_hint_original = self.cond_hint_original c.strength = self.strength return c @@ -772,7 +796,7 @@ def load_controlnet(ckpt_path, model=None): use_spatial_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) else: @@ -789,7 +813,7 @@ def load_controlnet(ckpt_path, model=None): use_linear_in_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) if pth: @@ -819,7 +843,11 @@ def load_controlnet(ckpt_path, model=None): if use_fp16: control_model = control_model.half() - control = ControlNet(control_model) + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control class T2IAdapter: @@ -843,10 +871,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() @@ -1070,7 +1102,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o } unet_config = { - "use_checkpoint": True, + "use_checkpoint": False, "image_size": 32, "out_channels": 4, "attention_resolutions": [ diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 0221fff83..87e3d859e 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -56,7 +56,12 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - return self.op(x) + if not self.use_conv: + padding = [x.shape[2] % 2, x.shape[3] % 2] + self.op.padding = padding + + x = self.op(x) + return x class ResnetBlock(nn.Module): diff --git a/comfy/utils.py b/comfy/utils.py index 09e05d4ed..4e84e870b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,16 @@ import torch import math +import struct def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: @@ -46,6 +51,88 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack(' max_size: + return None + return f.read(length_of_header) + +def bislerp(samples, width, height): + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] + + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) + + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms + + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 + + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) + so = torch.sin(omega) + + #technically not mathematically correct, but more pleasing? + res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) + + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) + + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) + + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) + + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) + + result = slerp(pass_1, pass_2, ratios) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) + return result + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -61,7 +148,11 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples[:,:,y:old_height-y,x:old_width-x] else: s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if upscale_method == "bislerp": + return bislerp(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 000000000..f4d52aa1e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 000000000..d7a129696 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + + +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 000000000..477e81f9d --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." + block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 000000000..dec169520 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. / n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 000000000..f9ce7f7a6 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. + """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 000000000..731a25f75 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 000000000..4260fb7c9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. + """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05dd..b50db7c24 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc4..d7bc5d227 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d1..2e66e6247 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47a..1906c0c7f 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ] diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..3be141dfe 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,6 +59,12 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) +def gaussian_kernel(kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + class Blur: def __init__(self): pass @@ -88,12 +94,6 @@ class Blur: CATEGORY = "image/postprocessing" - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) @@ -101,10 +101,11 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) @@ -167,9 +168,15 @@ class Sharpen: "max": 31, "step": 1 }), - "alpha": ("FLOAT", { + "sigma": ("FLOAT", { "default": 1.0, "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, "max": 5.0, "step": 0.1 }), @@ -181,21 +188,21 @@ class Sharpen: CATEGORY = "image/postprocessing" - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): if sharpen_radius == 0: return (image,) batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) diff --git a/execution.py b/execution.py index c19c10bc6..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -26,27 +27,96 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + intput_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + intput_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if intput_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + +def format_value(x): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -55,23 +125,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) - obj = class_def() - - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -105,7 +216,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -144,10 +256,53 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -155,6 +310,10 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -169,105 +328,250 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) - executed = set() - try: - to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] - - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] - - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) - except Exception as e: - print(traceback.format_exc()) - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) del d - finally: - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + executed = set() + output_node_id = None + to_execute = [] + + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] + + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] + + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) - if r[0] == False: - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type, + "linked_node": val + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) + + validated[unique_id] = ret + return ret + +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ def validate_prompt(prompt): outputs = set() @@ -277,34 +581,86 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] + node_errors = {} + validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "exception_type": exception_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: - good_outputs.add(x) + if valid is True: + good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. + if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "") + error = { + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: @@ -340,8 +696,7 @@ class PromptQueue: prompt = self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..a1bf1444d 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,8 @@ import os +import time -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} @@ -38,6 +32,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +filename_list_cache = {} + if not os.path.exists(input_directory): os.makedirs(input_directory) @@ -118,12 +114,18 @@ def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): + if not os.path.isdir(directory): + return [], {} result = [] + dirs = {directory: os.path.getmtime(directory)} for root, subdir, file in os.walk(directory, followlinks=True): for filepath in file: #we os.path,join directory with a blank string to generate a path separator at the end. result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result + for d in subdir: + path = os.path.join(root, d) + dirs[path] = os.path.getmtime(path) + return result, dirs def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -132,19 +134,90 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths + if folder_name not in folder_names_and_paths: + return None folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None -def get_filename_list(folder_name): +def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] + output_folders = {} for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) + files, folders_all = recursive_search(x) + output_list.update(filter_files_extensions(files, folders[1])) + output_folders = {**output_folders, **folders_all} + return (sorted(list(output_list)), output_folders, time.perf_counter()) +def cached_filename_list_(folder_name): + global filename_list_cache + global folder_names_and_paths + if folder_name not in filename_list_cache: + return None + out = filename_list_cache[folder_name] + if time.perf_counter() < (out[2] + 0.5): + return out + for x in out[1]: + time_modified = out[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + if os.path.isdir(x): + if x not in out[1]: + return None + + return out + +def get_filename_list(folder_name): + out = cached_filename_list_(folder_name) + if out is None: + out = get_filename_list_(folder_name) + global filename_list_cache + filename_list_cache[folder_name] = out + return list(out[0]) + +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/main.py b/main.py index eb97a2fb8..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -33,8 +33,8 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) - q.task_done(item_id, e.outputs) + e.execute(item[2], item[1], item[3], item[4]) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) diff --git a/nodes.py b/nodes.py index 20776bb37..7c3bec996 100644 --- a/nodes.py +++ b/nodes.py @@ -6,16 +6,18 @@ import json import hashlib import traceback import math +import time -from PIL import Image +from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -28,6 +30,7 @@ import importlib import folder_paths + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -145,9 +148,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -160,9 +160,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -175,9 +172,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -202,9 +196,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -219,9 +210,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} @@ -260,6 +248,81 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"prompt": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) + + return {} + + +class LoadLatent: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") + samples = {"samples": latent["latent_tensor"].float()} + return (samples, ) + + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -296,7 +359,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -306,12 +372,12 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: @@ -360,6 +426,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {}) return (model_lora, clip_lora) @@ -517,9 +586,11 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) - print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) @@ -624,6 +695,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() @@ -706,22 +780,61 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: - upscale_methods = ["nearest-exact", "bilinear", "area"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] crop_methods = ["disabled", "center"] @classmethod @@ -740,6 +853,25 @@ class LatentUpscale: s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) +class LatentUpscaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, samples, upscale_method, scale_by): + s = samples.copy() + width = round(samples["samples"].shape[3] * scale_by) + height = round(samples["samples"].shape[2] * scale_by) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") + return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -872,7 +1004,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): @@ -882,8 +1014,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -978,39 +1110,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() @@ -1049,8 +1149,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), )}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1060,6 +1161,7 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1090,9 +1192,10 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), ), - "channel": (s._color_channels, ),} + {"image": (sorted(files), ), + "channel": (s._color_channels, ), } } CATEGORY = "mask" @@ -1102,6 +1205,7 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") mask = None @@ -1244,7 +1348,9 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentUpscaleBy": LatentUpscaleBy, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1282,6 +1388,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1319,7 +1428,10 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", + "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1351,14 +1463,18 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1367,11 +1483,25 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - load_custom_node(module_path) + if module_path.endswith(".disabled"): continue + time_before = time.perf_counter() + success = load_custom_node(module_path) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) + + if len(node_import_times) > 0: + print("\nImport times for custom nodes:") + for n in sorted(node_import_times): + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) + print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", diff --git a/server.py b/server.py index 1c5c17916..5be822a6f 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,9 @@ import execution import uuid import json import glob +from PIL import Image +from io import BytesIO + try: import aiohttp from aiohttp import web @@ -19,7 +22,8 @@ except ImportError: import mimetypes from comfy.cli_args import args - +import comfy.utils +import comfy.model_management @web.middleware async def cache_control(request: web.Request, handler): @@ -78,7 +82,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -110,49 +114,96 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) - @routes.post("/upload/image") - async def upload_image(request): - post = await request.post() + def get_dir_by_type(dir_type): + if dir_type is None: + dir_type = "input" + + if dir_type == "input": + type_dir = folder_paths.get_input_directory() + elif dir_type == "temp": + type_dir = folder_paths.get_temp_directory() + elif dir_type == "output": + type_dir = folder_paths.get_output_directory() + + return type_dir, dir_type + + def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") - if post.get("type") is None: - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "input": - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "temp": - upload_dir = folder_paths.get_temp_directory() - elif post.get("type") == "output": - upload_dir = folder_paths.get_output_directory() - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) + image_upload_type = post.get("type") + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename if not filename: return web.Response(status=400) + subfolder = post.get("subfolder", "") + full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + + if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + return web.Response(status=400) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + split = os.path.splitext(filename) - i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 + filepath = os.path.join(full_output_folder, filename) - filepath = os.path.join(upload_dir, filename) + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 - with open(filepath, "wb") as f: - f.write(image.file.read()) - - return web.json_response({"name" : filename}) + if image_save_function is not None: + image_save_function(image, post, filepath) + else: + with open(filepath, "wb") as f: + f.write(image.file.read()) + + return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: return web.Response(status=400) + @routes.post("/upload/image") + async def upload_image(request): + post = await request.post() + return image_upload(post) + + @routes.post("/upload/mask") + async def upload_mask(request): + post = await request.post() + + def image_save_function(image, post, filepath): + original_pil = Image.open(post.get("original_image").file).convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + original_pil.save(filepath, compress_level=4) + + return image_upload(post, image_save_function) @routes.get("/view") async def view_image(request): if "filename" in request.rel_url.query: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + filename = request.rel_url.query["filename"] + filename,output_dir = folder_paths.annotated_filepath(filename) + + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' in filename: + return web.Response(status=400) + + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) @@ -162,35 +213,132 @@ class PromptServer(): return web.Response(status=403) output_dir = full_output_dir - filename = request.rel_url.query["filename"] filename = os.path.basename(filename) file = os.path.join(output_dir, filename) if os.path.isfile(file): - return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - + if 'channel' not in request.rel_url.query: + channel = 'rgba' + else: + channel = request.rel_url.query["channel"] + + if channel == 'rgb': + with Image.open(file) as img: + if img.mode == "RGBA": + r, g, b, a = img.split() + new_img = Image.merge('RGB', (r, g, b)) + else: + new_img = img.convert("RGB") + + buffer = BytesIO() + new_img.save(buffer, format='PNG') + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + + elif channel == 'a': + with Image.open(file) as img: + if img.mode == "RGBA": + _, _, _, a = img.split() + else: + a = Image.new('L', img.size, 255) + + # alpha img + alpha_img = Image.new('RGBA', img.size) + alpha_img.putalpha(a) + alpha_buffer = BytesIO() + alpha_img.save(alpha_buffer, format='PNG') + alpha_buffer.seek(0) + + return web.Response(body=alpha_buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + else: + return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) + return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + + @routes.get("/system_stats") + async def get_queue(request): + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) + vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + system_stats = { + "devices": [ + { + "name": device_name, + "type": device.type, + "index": device.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + } + ] + } + return web.json_response(system_stats) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) + def node_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = x - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info + out[x] = node_info(x) + return web.json_response(out) + + @routes.get("/object_info/{node_class}") + async def get_object_info_node(request): + node_class = request.match_info.get("node_class", None) + out = {} + if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + out[node_class] = node_info(node_class) return web.json_response(out) @routes.get("/history") @@ -232,14 +380,16 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + return web.json_response({"prompt_id": prompt_id, "number": number}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) + else: + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() @@ -249,9 +399,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") @@ -275,7 +425,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): diff --git a/web/extensions/core/clipspace.js b/web/extensions/core/clipspace.js new file mode 100644 index 000000000..adb5877ea --- /dev/null +++ b/web/extensions/core/clipspace.js @@ -0,0 +1,166 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; + +export class ClipspaceDialog extends ComfyDialog { + static items = []; + static instance = null; + + static registerButton(name, contextPredicate, callback) { + const item = + $el("button", { + type: "button", + textContent: name, + contextPredicate: contextPredicate, + onclick: callback + }) + + ClipspaceDialog.items.push(item); + } + + static invalidatePreview() { + if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) { + const img_preview = document.getElementById("clipspace_preview"); + if(img_preview) { + img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + img_preview.style.maxHeight = "100%"; + img_preview.style.maxWidth = "100%"; + } + } + } + + static invalidate() { + if(ClipspaceDialog.instance) { + const self = ClipspaceDialog.instance; + // allow reconstruct controls when copying from non-image to image content. + const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]); + + if(self.element) { + // update + self.element.removeChild(self.element.firstChild); + self.element.appendChild(children); + } + else { + // new + self.element = $el("div.comfy-modal", { parent: document.body }, [children,]); + } + + if(self.element.children[0].children.length <= 1) { + self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."])); + } + + ClipspaceDialog.invalidatePreview(); + } + } + + constructor() { + super(); + } + + createButtons(self) { + const buttons = []; + + for(let idx in ClipspaceDialog.items) { + const item = ClipspaceDialog.items[idx]; + if(!item.contextPredicate || item.contextPredicate()) + buttons.push(ClipspaceDialog.items[idx]); + } + + buttons.push( + $el("button", { + type: "button", + textContent: "Close", + onclick: () => { this.close(); } + }) + ); + + return buttons; + } + + createImgSettings() { + if(ComfyApp.clipspace.imgs) { + const combo_items = []; + const imgs = ComfyApp.clipspace.imgs; + + for(let i=0; i < imgs.length; i++) { + combo_items.push($el("option", {value:i}, [`${i}`])); + } + + const combo1 = $el("select", + {id:"clipspace_img_selector", onchange:(event) => { + ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex; + ClipspaceDialog.invalidatePreview(); + } }, combo_items); + + const row1 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]), + $el("td", {}, [combo1]) + ]); + + + const combo2 = $el("select", + {id:"clipspace_img_paste_mode", onchange:(event) => { + ComfyApp.clipspace['img_paste_mode'] = event.target.value; + } }, + [ + $el("option", {value:'selected'}, 'selected'), + $el("option", {value:'all'}, 'all') + ]); + combo2.value = ComfyApp.clipspace['img_paste_mode']; + + const row2 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]), + $el("td", {}, [combo2]) + ]); + + const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'}, + [ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]); + + const row3 = + $el("tr", {}, [td]); + + return $el("table", {}, [row1, row2, row3]); + } + else { + return []; + } + } + + createImgPreview() { + if(ComfyApp.clipspace.imgs) { + return $el("img",{id:"clipspace_preview", ondragstart:() => false}); + } + else + return []; + } + + show() { + const img_preview = document.getElementById("clipspace_preview"); + ClipspaceDialog.invalidate(); + + this.element.style.display = "block"; + } +} + +app.registerExtension({ + name: "Comfy.Clipspace", + init(app) { + app.openClipspace = + function () { + if(!ClipspaceDialog.instance) { + ClipspaceDialog.instance = new ClipspaceDialog(app); + ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate; + } + + if(ComfyApp.clipspace) { + ClipspaceDialog.instance.show(); + } + else + app.ui.dialog.show("Clipspace is Empty!"); + }; + } +}); \ No newline at end of file diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 2f2238a2b..bfcd847a3 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,7 +174,7 @@ const els = {} // const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, - init() { + addCustomNodeDefs(node_defs) { const sortObjectKeys = (unordered) => { return Object.keys(unordered).sort().reduce((obj, key) => { obj[key] = unordered[key]; @@ -182,10 +182,10 @@ app.registerExtension({ }, {}); }; - const getSlotTypes = async () => { + function getSlotTypes() { var types = []; - const defs = await api.getNodeDefs(); + const defs = node_defs; for (const nodeId in defs) { const nodeData = defs[nodeId]; @@ -212,8 +212,8 @@ app.registerExtension({ return types; }; - const completeColorPalette = async (colorPalette) => { - var types = await getSlotTypes(); + function completeColorPalette(colorPalette) { + var types = getSlotTypes(); for (const type of types) { if (!colorPalette.colors.node_slot[type]) { diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js new file mode 100644 index 000000000..4b0c12747 --- /dev/null +++ b/web/extensions/core/maskeditor.js @@ -0,0 +1,648 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; +import { ClipspaceDialog } from "/extensions/core/clipspace.js"; + +// Helper function to convert a data URL to a Blob object +function dataURLToBlob(dataURL) { + const parts = dataURL.split(';base64,'); + const contentType = parts[0].split(':')[1]; + const byteString = atob(parts[1]); + const arrayBuffer = new ArrayBuffer(byteString.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < byteString.length; i++) { + uint8Array[i] = byteString.charCodeAt(i); + } + return new Blob([arrayBuffer], { type: contentType }); +} + +function loadedImageToBlob(image) { + const canvas = document.createElement('canvas'); + + canvas.width = image.width; + canvas.height = image.height; + + const ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0); + + const dataURL = canvas.toDataURL('image/png', 1); + const blob = dataURLToBlob(dataURL); + + return blob; +} + +async function uploadMask(filepath, formData) { + await fetch('/upload/mask', { + method: 'POST', + body: formData + }).then(response => {}).catch(error => { + console.error('Error:', error); + }); + + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; + + ClipspaceDialog.invalidatePreview(); +} + +function prepareRGB(image, backupCanvas, backupCtx) { + // paste mask data into alpha channel + backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); +} + +class MaskEditorDialog extends ComfyDialog { + static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + + constructor() { + super(); + this.element = $el("div.comfy-modal", { parent: document.body }, + [ $el("div.comfy-modal-content", + [...this.createButtons()]), + ]); + } + + createButtons() { + return []; + } + + createButton(name, callback) { + var button = document.createElement("button"); + button.innerText = name; + button.addEventListener("click", callback); + return button; + } + + createLeftButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "left"; + button.style.marginRight = "4px"; + return button; + } + + createRightButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "right"; + button.style.marginLeft = "4px"; + return button; + } + + createLeftSlider(self, name, callback) { + const divElement = document.createElement('div'); + divElement.id = "maskeditor-slider"; + divElement.style.cssFloat = "left"; + divElement.style.fontFamily = "sans-serif"; + divElement.style.marginRight = "4px"; + divElement.style.color = "var(--input-text)"; + divElement.style.backgroundColor = "var(--comfy-input-bg)"; + divElement.style.borderRadius = "8px"; + divElement.style.borderColor = "var(--border-color)"; + divElement.style.borderStyle = "solid"; + divElement.style.fontSize = "15px"; + divElement.style.height = "21px"; + divElement.style.padding = "1px 6px"; + divElement.style.display = "flex"; + divElement.style.position = "relative"; + divElement.style.top = "2px"; + self.brush_slider_input = document.createElement('input'); + self.brush_slider_input.setAttribute('type', 'range'); + self.brush_slider_input.setAttribute('min', '1'); + self.brush_slider_input.setAttribute('max', '100'); + self.brush_slider_input.setAttribute('value', '10'); + const labelElement = document.createElement("label"); + labelElement.textContent = name; + + divElement.appendChild(labelElement); + divElement.appendChild(self.brush_slider_input); + + self.brush_slider_input.addEventListener("change", callback); + + return divElement; + } + + setlayout(imgCanvas, maskCanvas) { + const self = this; + + // If it is specified as relative, using it only as a hidden placeholder for padding is recommended + // to prevent anomalies where it exceeds a certain size and goes outside of the window. + var placeholder = document.createElement("div"); + placeholder.style.position = "relative"; + placeholder.style.height = "50px"; + + var bottom_panel = document.createElement("div"); + bottom_panel.style.position = "absolute"; + bottom_panel.style.bottom = "0px"; + bottom_panel.style.left = "20px"; + bottom_panel.style.right = "20px"; + bottom_panel.style.height = "50px"; + + var brush = document.createElement("div"); + brush.id = "brush"; + brush.style.backgroundColor = "transparent"; + brush.style.outline = "1px dashed black"; + brush.style.boxShadow = "0 0 0 1px white"; + brush.style.borderRadius = "50%"; + brush.style.MozBorderRadius = "50%"; + brush.style.WebkitBorderRadius = "50%"; + brush.style.position = "absolute"; + brush.style.zIndex = 8889; + brush.style.pointerEvents = "none"; + this.brush = brush; + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + document.body.appendChild(brush); + + var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + self.brush_size = event.target.value; + self.updateBrushPreview(self, null, null); + }); + var clearButton = this.createLeftButton("Clear", + () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); + }); + var cancelButton = this.createRightButton("Cancel", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.close(); + }); + + this.saveButton = this.createRightButton("Save", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.save(); + }); + + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + + bottom_panel.appendChild(clearButton); + bottom_panel.appendChild(this.saveButton); + bottom_panel.appendChild(cancelButton); + bottom_panel.appendChild(brush_size_slider); + + imgCanvas.style.position = "relative"; + imgCanvas.style.top = "200"; + imgCanvas.style.left = "0"; + + maskCanvas.style.position = "absolute"; + } + + show() { + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); + + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; + + this.setlayout(imgCanvas, maskCanvas); + + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); + + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. + } + + isOpened() { + return this.element.style.display == "block"; + } + + setImages(imgCanvas, backupCanvas) { + const imgCtx = imgCanvas.getContext('2d'); + const backupCtx = backupCanvas.getContext('2d'); + const maskCtx = this.maskCtx; + const maskCanvas = this.maskCanvas; + + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + + // image load + const orig_image = new Image(); + window.addEventListener("resize", () => { + // repositioning + imgCanvas.width = window.innerWidth - 250; + imgCanvas.height = window.innerHeight - 200; + + // redraw image + let drawWidth = orig_image.width; + let drawHeight = orig_image.height; + if (orig_image.width > imgCanvas.width) { + drawWidth = imgCanvas.width; + drawHeight = (drawWidth / orig_image.width) * orig_image.height; + } + + if (drawHeight > imgCanvas.height) { + drawHeight = imgCanvas.height; + drawWidth = (drawHeight / orig_image.height) * orig_image.width; + } + + imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); + + // update mask + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); + maskCanvas.width = drawWidth; + maskCanvas.height = drawHeight; + maskCanvas.style.top = imgCanvas.offsetTop + "px"; + maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); + }); + + const filepath = ComfyApp.clipspace.images; + + const touched_image = new Image(); + + touched_image.onload = function() { + backupCanvas.width = touched_image.width; + backupCanvas.height = touched_image.height; + + prepareRGB(touched_image, backupCanvas, backupCtx); + }; + + const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) + alpha_url.searchParams.delete('channel'); + alpha_url.searchParams.set('channel', 'a'); + touched_image.src = alpha_url; + + // original image load + orig_image.onload = function() { + window.dispatchEvent(new Event('resize')); + }; + + const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); + rgb_url.searchParams.delete('channel'); + rgb_url.searchParams.set('channel', 'rgb'); + orig_image.src = rgb_url; + this.image = orig_image; + } + + setEventHandler(maskCanvas) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + const self = this; + maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + } + + brush_size = 10; + drawing_mode = false; + lastx = -1; + lasty = -1; + lasttime = 0; + + static handleKeyDown(event) { + const self = MaskEditorDialog.instance; + if (event.key === ']') { + self.brush_size = Math.min(self.brush_size+2, 100); + } else if (event.key === '[') { + self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); + } + + self.updateBrushPreview(self); + } + + static handlePointerUp(event) { + event.preventDefault(); + MaskEditorDialog.instance.drawing_mode = false; + } + + updateBrushPreview(self) { + const brush = self.brush; + + var centerX = self.cursorX; + var centerY = self.cursorY; + + brush.style.width = self.brush_size * 2 + "px"; + brush.style.height = self.brush_size * 2 + "px"; + brush.style.left = (centerX - self.brush_size) + "px"; + brush.style.top = (centerY - self.brush_size) + "px"; + } + + handleWheelEvent(self, event) { + if(event.deltaY < 0) + self.brush_size = Math.min(self.brush_size+2, 100); + else + self.brush_size = Math.max(self.brush_size-2, 1); + + self.brush_slider_input.value = self.brush_size; + + self.updateBrushPreview(self); + } + + draw_move(self, event) { + event.preventDefault(); + + this.cursorX = event.pageX; + this.cursorY = event.pageY; + + self.updateBrushPreview(self); + + if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { + var diff = performance.now() - self.lasttime; + + const maskRect = self.maskCanvas.getBoundingClientRect(); + + var x = event.offsetX; + var y = event.offsetY + + if(event.offsetX == null) { + x = event.targetTouches[0].clientX - maskRect.left; + } + + if(event.offsetY == null) { + y = event.targetTouches[0].clientY - maskRect.top; + } + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ + // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. + brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !this.drawing_mode) + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ + brush_size *= this.last_pressure; + } + else { + brush_size = this.brush_size; + } + + if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + }); + else + requestAnimationFrame(() => { + self.maskCtx.beginPath(); + self.maskCtx.globalCompositeOperation = "destination-out"; + + var dx = x - self.lastx; + var dy = y - self.lasty; + + var distance = Math.sqrt(dx * dx + dy * dy); + var directionX = dx / distance; + var directionY = dy / distance; + + for (var i = 0; i < distance; i+=5) { + var px = self.lastx + (directionX * i); + var py = self.lasty + (directionY * i); + self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + } + self.lastx = x; + self.lasty = y; + }); + + self.lasttime = performance.now(); + } + } + + handlePointerDown(self, event) { + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + + if ([0, 2, 5].includes(event.button)) { + self.drawing_mode = true; + + event.preventDefault(); + const maskRect = self.maskCanvas.getBoundingClientRect(); + const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; + const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + + self.maskCtx.beginPath(); + if (event.button == 0) { + self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.globalCompositeOperation = "source-over"; + } else { + self.maskCtx.globalCompositeOperation = "destination-out"; + } + self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); + self.maskCtx.fill(); + self.lastx = x; + self.lasty = y; + self.lasttime = performance.now(); + } + } + + async save() { + const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); + + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.drawImage(this.maskCanvas, + 0, 0, this.maskCanvas.width, this.maskCanvas.height, + 0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // paste mask data into alpha channel + const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); + + const formData = new FormData(); + const filename = "clipspace-mask-" + performance.now() + ".png"; + + const item = + { + "filename": filename, + "subfolder": "clipspace", + "type": "input", + }; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[0] = item; + + if(ComfyApp.clipspace.widgets) { + const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + + if(index >= 0) + ComfyApp.clipspace.widgets[index].value = item; + } + + const dataURL = this.backupCanvas.toDataURL(); + const blob = dataURLToBlob(dataURL); + + const original_blob = loadedImageToBlob(this.image); + + formData.append('image', blob, filename); + formData.append('original_image', original_blob); + formData.append('type', "input"); + formData.append('subfolder', "clipspace"); + + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); + this.close(); + } +} + +app.registerExtension({ + name: "Comfy.MaskEditor", + init(app) { + ComfyApp.open_maskeditor = + function () { + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } + }; + + const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); + } +}); \ No newline at end of file diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/index.html b/web/index.html index bb79433ce..da0adb6c2 100644 --- a/web/index.html +++ b/web/index.html @@ -14,5 +14,5 @@ window.graph = app.graph; - + diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; @@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..378165b3a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -88,6 +88,12 @@ class ComfyApi extends EventTarget { case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; default: if (this.#registered.has(msg.type)) { this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); @@ -163,7 +169,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 245605484..8a9c7ca49 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -25,6 +25,9 @@ export class ComfyApp { * @type {serialized node object} */ static clipspace = null; + static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -48,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -137,81 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - let img = new Image(); - var imgs = undefined; - if(this.imgs != undefined) { - img.src = this.imgs[0].src; - imgs = [img]; - } + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': imgs, - 'images': this.images - }; - } - }); + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace != null) { - if(ComfyApp.clipspace.widgets != null && this.widgets != null) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop) { - prop.callback(value); - } - }); - } - - // image paste - if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { - var filename = ""; - if(this.images && ComfyApp.clipspace.images) { - this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.images != undefined) { - const clip_image = ComfyApp.clipspace.images[0]; - if(clip_image.subfolder != '') - filename = `${clip_image.subfolder}/`; - filename += `${clip_image.filename} [${clip_image.type}]`; - } - else if(ComfyApp.clipspace.widgets != undefined) { - const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); - if(index_in_clip >= 0) { - filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; - } - } - - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { - this.imgs = ComfyApp.clipspace.imgs; - - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; - } - } - } - this.trigger('changed'); + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - } - ); + }); + } } }; } @@ -711,16 +771,27 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -747,11 +818,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -809,6 +897,17 @@ export class ComfyApp { } }); + api.addEventListener("execution_start", ({ detail }) => { + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + api.init(); } @@ -842,7 +941,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); @@ -909,6 +1010,11 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); + await this.registerNodesFromDefs(defs); + await this.#invokeExtensionsAsync("registerCustomNodes"); + } + + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -954,7 +1060,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); @@ -980,8 +1087,6 @@ export class ComfyApp { LiteGraph.registerNodeType(nodeId, node); node.category = nodeData.category; } - - await this.#invokeExtensionsAsync("registerCustomNodes"); } /** @@ -1180,6 +1285,43 @@ export class ComfyApp { return { workflow, output }; } + #formatPromptError(error) { + if (error == null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1187,8 +1329,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1199,7 +1343,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + const formattedError = this.#formatPromptError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1245,6 +1394,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } @@ -1273,14 +1427,19 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. + if(!def) + continue; + for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] - if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(!widget.options.values.includes(widget.value)) { + if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; + widget.callback(widget.value); } } } @@ -1292,6 +1451,8 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.lastPromptError = null; + this.lastExecutionError = null; } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,12 +47,29 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 5accc9d86..2c9043d00 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { @@ -581,6 +581,7 @@ export class ComfyUI { }), $el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index cd471bc93..82168b08b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; - //adjust values based on valueControl Behaviour - switch (v) { - case "fixed": - break; - case "increment": - targetWidget.value += targetWidget.options.step / 10; - break; - case "decrement": - targetWidget.value -= targetWidget.options.step / 10; - break; - case "randomize": - targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; - default: - break; + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let min = targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; @@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) { computeSize(node.size); } const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const t = ctx.getTransform(); const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + Object.assign(this.inputEl.style, { - left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, - width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - background: (!node.color)?'':node.color, - height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, + transformOrigin: "0 0", + transform: transform, + left: "0px", + top: "0px", + width: `${widgetWidth - (margin * 2)}px`, + height: `${this.parent.inputHeight - (margin * 2)}px`, position: "absolute", + background: (!node.color)?'':node.color, color: (!node.color)?'':'white', zIndex: app.graph._nodes.indexOf(node), - fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; }, @@ -266,10 +297,46 @@ export const ComfyWidgets = { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); + } + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; node.setSizeForImage?.(); } + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set : function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; + } + }); + // Add our own callback to the combo widget to render an image when it changes const cb = node.callback; imageWidget.callback = function () { diff --git a/web/style.css b/web/style.css index df220cc02..47571a16e 100644 --- a/web/style.css +++ b/web/style.css @@ -39,6 +39,8 @@ body { padding: 2px; resize: none; border: none; + box-sizing: border-box; + font-size: 10px; } .comfy-modal { @@ -287,6 +289,11 @@ button.comfy-queue-btn { /* Context menu */ +.litegraph .dialog { + z-index: 1; + font-family: Arial; +} + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; @@ -329,6 +336,7 @@ button.comfy-queue-btn { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; overflow: hidden; + display: block; } .litegraph.litesearchbox input,