mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-14 00:00:57 +08:00
Merge branch 'master' of https://github.com/comfyanonymous/ComfyUI into dpr
This commit is contained in:
commit
9cd57a9dab
@ -30,6 +30,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
|
||||
@ -17,6 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11.3'
|
||||
|
||||
@ -56,6 +56,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
| Q | Toggle visibility of the queue |
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for MacOS users
|
||||
|
||||
|
||||
@ -5,17 +5,17 @@ import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
from ..ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
from ..ldm.modules.attention import SpatialTransformer
|
||||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ..ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
|
||||
@ -7,6 +7,7 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
@ -30,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test
|
||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from ldm.modules.attention import CrossAttention
|
||||
from .ldm.modules.attention import CrossAttention
|
||||
from inspect import isfunction
|
||||
|
||||
|
||||
@ -242,14 +242,28 @@ class Gligen(nn.Module):
|
||||
self.position_net = position_net
|
||||
self.key_dim = key_dim
|
||||
self.max_objs = 30
|
||||
self.lowvram = False
|
||||
|
||||
def _set_position(self, boxes, masks, positive_embeddings):
|
||||
if self.lowvram == True:
|
||||
self.position_net.to(boxes.device)
|
||||
|
||||
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||
|
||||
def func(key, x):
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
if self.lowvram == True:
|
||||
self.position_net.cpu()
|
||||
def func_lowvram(key, x):
|
||||
module = self.module_list[key]
|
||||
module.to(x.device)
|
||||
r = module(x, objs)
|
||||
module.cpu()
|
||||
return r
|
||||
return func_lowvram
|
||||
else:
|
||||
def func(key, x):
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
batch, c, h, w = latent_image_shape
|
||||
@ -294,8 +308,11 @@ class Gligen(nn.Module):
|
||||
masks.to(device),
|
||||
conds.to(device))
|
||||
|
||||
def set_lowvram(self, value=True):
|
||||
self.lowvram = value
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
self.lowvram = False
|
||||
|
||||
def get_models(self):
|
||||
return [self]
|
||||
|
||||
@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
|
||||
if solver_type not in {'heun', 'midpoint'}:
|
||||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
|
||||
@ -3,11 +3,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
|
||||
# class AutoencoderKL(pl.LightningModule):
|
||||
class AutoencoderKL(torch.nn.Module):
|
||||
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
||||
@ -19,12 +19,12 @@ from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
# from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
|
||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||
from ldm.modules.ema import LitEma
|
||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ..autoencoder import IdentityFirstStage, AutoencoderKL
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
from .ddim import DDIMSampler
|
||||
|
||||
|
||||
__conditioning_keys__ = {'concat': 'c_concat',
|
||||
|
||||
@ -6,7 +6,7 @@ from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
from .diffusionmodules.util import checkpoint
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
@ -21,7 +21,7 @@ if model_management.xformers_enabled():
|
||||
import os
|
||||
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
|
||||
|
||||
from cli_args import args
|
||||
from comfy.cli_args import args
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
if current_index is not None:
|
||||
transformer_options["current_index"] += 1
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ import numpy as np
|
||||
from einops import rearrange
|
||||
from typing import Optional, Any
|
||||
|
||||
from ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
from ..attention import MemoryEfficientCrossAttention
|
||||
from comfy import model_management
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
return x+h
|
||||
|
||||
def slice_attention(q, k, v):
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
scale = (int(q.shape[-1])**(-0.5))
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
|
||||
return r1
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
scale = (int(c)**(-0.5))
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
del r1
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
from .util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import (
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.util import exists
|
||||
from ..attention import SpatialTransformer
|
||||
from comfy.ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
@ -805,13 +818,13 @@ class UNetModel(nn.Module):
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
h += control['middle'].pop()
|
||||
|
||||
@ -828,7 +841,7 @@ class UNetModel(nn.Module):
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = module(h, emb, context, transformer_options, output_shape)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
||||
@ -3,8 +3,8 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ldm.util import default
|
||||
from .util import extract_into_tensor, make_beta_schedule
|
||||
from comfy.ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
|
||||
@ -15,7 +15,7 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ..diffusionmodules.openaimodel import Timestep
|
||||
import torch
|
||||
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
|
||||
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||
"""
|
||||
B, N, _ = metric.shape
|
||||
|
||||
if r <= 0:
|
||||
if r <= 0 or w == 1 or h == 1:
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from cli_args import args
|
||||
from comfy.cli_args import args
|
||||
|
||||
class VRAMState(Enum):
|
||||
CPU = 0
|
||||
@ -127,6 +127,32 @@ if args.cpu:
|
||||
|
||||
print(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_torch_device_name(device):
|
||||
if hasattr(device, 'type'):
|
||||
return "{}".format(device.type)
|
||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||
|
||||
try:
|
||||
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||
except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
current_gpu_controlnets = []
|
||||
@ -201,6 +227,9 @@ def load_controlnet_gpu(control_models):
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
for m in control_models:
|
||||
if hasattr(m, 'set_lowvram'):
|
||||
m.set_lowvram(True)
|
||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||
return
|
||||
|
||||
@ -230,22 +259,6 @@ def unload_if_low_vram(model):
|
||||
return model.cpu()
|
||||
return model
|
||||
|
||||
def get_torch_device():
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
return directml_device
|
||||
if vram_state == VRAMState.MPS:
|
||||
return torch.device("mps")
|
||||
if vram_state == VRAMState.CPU:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
if xpu_available:
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_autocast_device(dev):
|
||||
if hasattr(dev, 'type'):
|
||||
return dev.type
|
||||
@ -272,8 +285,17 @@ def xformers_enabled_vae():
|
||||
return XFORMERS_ENABLED_VAE
|
||||
|
||||
def pytorch_attention_enabled():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
return ENABLE_PYTORCH_ATTENTION
|
||||
|
||||
def pytorch_attention_flash_attention():
|
||||
global ENABLE_PYTORCH_ATTENTION
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
#TODO: more reliable way of checking for flash attention?
|
||||
if torch.version.cuda: #pytorch flash attention only works on Nvidia
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global xpu_available
|
||||
global directml_enabled
|
||||
@ -309,7 +331,12 @@ def maximum_batch_area():
|
||||
return 0
|
||||
|
||||
memory_free = get_free_memory() / (1024 * 1024)
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = 20 * memory_free
|
||||
else:
|
||||
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
|
||||
@ -2,17 +2,26 @@ import torch
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def prepare_noise(latent_image, seed, skip=0):
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
for _ in range(skip):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return noise
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
"""ensures noise mask is of proper dimensions"""
|
||||
|
||||
@ -6,6 +6,10 @@ import contextlib
|
||||
from comfy import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
|
||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||
return abs(a*b) // math.gcd(a, b)
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||
return False
|
||||
s1 = c1['c_crossattn'].shape
|
||||
s2 = c2['c_crossattn'].shape
|
||||
if s1 != s2:
|
||||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||
return False
|
||||
|
||||
mult_min = lcm(s1[1], s2[1])
|
||||
diff = mult_min // min(s1[1], s2[1])
|
||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c_crossattn.append(x['c_crossattn'])
|
||||
c = x['c_crossattn']
|
||||
if crossattn_max_len == 0:
|
||||
crossattn_max_len = c.shape[1]
|
||||
else:
|
||||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||
c_crossattn.append(c)
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
if 'c_adm' in x:
|
||||
c_adm.append(x['c_adm'])
|
||||
out = {}
|
||||
if len(c_crossattn) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||
c_crossattn_out = []
|
||||
for c in c_crossattn:
|
||||
if c.shape[1] < crossattn_max_len:
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
if len(c_adm) > 0:
|
||||
@ -362,19 +386,8 @@ def resolve_cond_masks(conditions, h, w, device):
|
||||
else:
|
||||
box = boxes[0]
|
||||
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
||||
# Make sure the height and width are divisible by 8
|
||||
if X % 8 != 0:
|
||||
newx = X // 8 * 8
|
||||
W = W + (X - newx)
|
||||
X = newx
|
||||
if Y % 8 != 0:
|
||||
newy = Y // 8 * 8
|
||||
H = H + (Y - newy)
|
||||
Y = newy
|
||||
if H % 8 != 0:
|
||||
H = H + (8 - (H % 8))
|
||||
if W % 8 != 0:
|
||||
W = W + (8 - (W % 8))
|
||||
H = max(8, H)
|
||||
W = max(8, W)
|
||||
area = (int(H), int(W), int(Y), int(X))
|
||||
modified['area'] = area
|
||||
|
||||
@ -482,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
||||
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -519,6 +532,8 @@ class KSampler:
|
||||
|
||||
if self.scheduler == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||
elif self.scheduler == "normal":
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
elif self.scheduler == "simple":
|
||||
|
||||
43
comfy/sd.py
43
comfy/sd.py
@ -2,8 +2,8 @@ import torch
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
import sd1_clip
|
||||
import sd2_clip
|
||||
from . import sd1_clip
|
||||
from . import sd2_clip
|
||||
from comfy import model_management
|
||||
from .ldm.util import instantiate_from_config
|
||||
from .ldm.models.autoencoder import AutoencoderKL
|
||||
@ -446,10 +446,10 @@ class CLIP:
|
||||
else:
|
||||
params = {}
|
||||
|
||||
if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
|
||||
if self.target_clip.endswith("FrozenOpenCLIPEmbedder"):
|
||||
clip = sd2_clip.SD2ClipModel
|
||||
tokenizer = sd2_clip.SD2Tokenizer
|
||||
elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
|
||||
elif self.target_clip.endswith("FrozenCLIPEmbedder"):
|
||||
clip = sd1_clip.SD1ClipModel
|
||||
tokenizer = sd1_clip.SD1Tokenizer
|
||||
|
||||
@ -581,12 +581,9 @@ class VAE:
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
||||
target_batch_size = target_latent_tensor.shape[0]
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
print(current_batch_size, target_batch_size)
|
||||
#print(current_batch_size, target_batch_size)
|
||||
if current_batch_size == 1:
|
||||
return tensor
|
||||
|
||||
@ -623,7 +620,9 @@ class ControlNet:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
if self.control_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
@ -794,10 +793,14 @@ class T2IAdapter:
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.control_input = None
|
||||
self.cond_hint = None
|
||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
self.t2i_model.to(self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint)
|
||||
self.t2i_model.cpu()
|
||||
@ -896,9 +899,9 @@ def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = utils.load_torch_file(ckpt_path)
|
||||
config = {}
|
||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
||||
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
else:
|
||||
config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip = CLIP(config=config, embedding_directory=embedding_directory)
|
||||
clip.load_from_state_dict(clip_data)
|
||||
return clip
|
||||
@ -974,9 +977,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_clip:
|
||||
clip_config = {}
|
||||
if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys:
|
||||
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||
else:
|
||||
clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
|
||||
clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
load_state_dict_to = [w]
|
||||
@ -997,7 +1000,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
|
||||
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
|
||||
params["noise_schedule_config"] = noise_schedule_config
|
||||
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
|
||||
noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
|
||||
if size == 1280: #h
|
||||
params["timestep_dim"] = 1024
|
||||
elif size == 1024: #l
|
||||
@ -1049,19 +1052,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
|
||||
|
||||
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||
|
||||
if noise_aug_config is not None: #SD2.x unclip model
|
||||
sd_config["noise_aug_config"] = noise_aug_config
|
||||
sd_config["image_size"] = 96
|
||||
sd_config["embedding_dropout"] = 0.25
|
||||
sd_config["conditioning_key"] = 'crossattn-adm'
|
||||
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||
elif unet_config["in_channels"] > 4: #inpainting model
|
||||
sd_config["conditioning_key"] = "hybrid"
|
||||
sd_config["finetune_keys"] = None
|
||||
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
else:
|
||||
sd_config["conditioning_key"] = "crossattn"
|
||||
|
||||
|
||||
@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path):
|
||||
del embed
|
||||
return out
|
||||
|
||||
def expand_directory_list(directories):
|
||||
dirs = set()
|
||||
for x in directories:
|
||||
dirs.add(x)
|
||||
for root, subdir, file in os.walk(x, followlinks=True):
|
||||
dirs.add(root)
|
||||
return list(dirs)
|
||||
|
||||
def load_embed(embedding_name, embedding_directory):
|
||||
if isinstance(embedding_directory, str):
|
||||
embedding_directory = [embedding_directory]
|
||||
|
||||
embedding_directory = expand_directory_list(embedding_directory)
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
|
||||
@ -56,7 +56,12 @@ class Downsample(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
if not self.use_conv:
|
||||
padding = [x.shape[2] % 2, x.shape[3] % 2]
|
||||
self.op.padding = padding
|
||||
|
||||
x = self.op(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
@ -46,6 +46,72 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return sd
|
||||
|
||||
#slow and inefficient, should be optimized
|
||||
def bislerp(samples, width, height):
|
||||
shape = list(samples.shape)
|
||||
width_scale = (shape[3]) / (width )
|
||||
height_scale = (shape[2]) / (height )
|
||||
|
||||
shape[3] = width
|
||||
shape[2] = height
|
||||
out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device)
|
||||
|
||||
def algorithm(in1, in2, t):
|
||||
dims = in1.shape
|
||||
val = t
|
||||
|
||||
#flatten to batches
|
||||
low = in1.reshape(dims[0], -1)
|
||||
high = in2.reshape(dims[0], -1)
|
||||
|
||||
low_weight = torch.norm(low, dim=1, keepdim=True)
|
||||
low_weight[low_weight == 0] = 0.0000000001
|
||||
low_norm = low/low_weight
|
||||
high_weight = torch.norm(high, dim=1, keepdim=True)
|
||||
high_weight[high_weight == 0] = 0.0000000001
|
||||
high_norm = high/high_weight
|
||||
|
||||
dot_prod = (low_norm*high_norm).sum(1)
|
||||
dot_prod[dot_prod > 0.9995] = 0.9995
|
||||
dot_prod[dot_prod < -0.9995] = -0.9995
|
||||
omega = torch.acos(dot_prod)
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
|
||||
res *= (low_weight * (1.0-val) + high_weight * val)
|
||||
return res.reshape(dims)
|
||||
|
||||
for x_dest in range(shape[3]):
|
||||
for y_dest in range(shape[2]):
|
||||
y = (y_dest + 0.5) * height_scale - 0.5
|
||||
x = (x_dest + 0.5) * width_scale - 0.5
|
||||
|
||||
x1 = max(math.floor(x), 0)
|
||||
x2 = min(x1 + 1, samples.shape[3] - 1)
|
||||
wx = x - math.floor(x)
|
||||
|
||||
y1 = max(math.floor(y), 0)
|
||||
y2 = min(y1 + 1, samples.shape[2] - 1)
|
||||
wy = y - math.floor(y)
|
||||
|
||||
in1 = samples[:,:,y1,x1]
|
||||
in2 = samples[:,:,y1,x2]
|
||||
in3 = samples[:,:,y2,x1]
|
||||
in4 = samples[:,:,y2,x2]
|
||||
|
||||
if (x1 == x2) and (y1 == y2):
|
||||
out_value = in1
|
||||
elif (x1 == x2):
|
||||
out_value = algorithm(in1, in3, wy)
|
||||
elif (y1 == y2):
|
||||
out_value = algorithm(in1, in2, wx)
|
||||
else:
|
||||
o1 = algorithm(in1, in2, wx)
|
||||
o2 = algorithm(in3, in4, wx)
|
||||
out_value = algorithm(o1, o2, wy)
|
||||
|
||||
out1[:,:,y_dest,x_dest] = out_value
|
||||
return out1
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
@ -61,7 +127,11 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
if upscale_method == "bislerp":
|
||||
return bislerp(s, width, height)
|
||||
else:
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CA_layer(nn.Module):
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(CA_layer, self).__init__()
|
||||
# global average pooling
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
|
||||
# nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc(self.gap(x))
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class Simple_CA_layer(nn.Module):
|
||||
def __init__(self, channel):
|
||||
super(Simple_CA_layer, self).__init__()
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=channel,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.fc(self.gap(x))
|
||||
|
||||
|
||||
class ECA_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.avg_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
||||
|
||||
|
||||
class ECA_MaxPool_layer(nn.Module):
|
||||
"""Constructs a ECA module.
|
||||
Args:
|
||||
channel: Number of channels of the input feature map
|
||||
k_size: Adaptive selection of kernel size
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
super(ECA_MaxPool_layer, self).__init__()
|
||||
|
||||
b = 1
|
||||
gamma = 2
|
||||
k_size = int(abs(math.log(channel, 2) + b) / gamma)
|
||||
k_size = k_size if k_size % 2 else k_size + 1
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
self.conv = nn.Conv1d(
|
||||
1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
|
||||
)
|
||||
# self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
# x: input features with shape [b, c, h, w]
|
||||
# b, c, h, w = x.size()
|
||||
|
||||
# feature descriptor on the global spatial information
|
||||
y = self.max_pool(x)
|
||||
|
||||
# Two different branches of ECA module
|
||||
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
||||
|
||||
# Multi-scale information fusion
|
||||
# y = self.sigmoid(y)
|
||||
|
||||
return x * y.expand_as(x)
|
||||
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
201
comfy_extras/chainner_models/architecture/OmniSR/LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
577
comfy_extras/chainner_models/architecture/OmniSR/OSA.py
Normal file
@ -0,0 +1,577 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSA.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
from torch import einsum, nn
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def cast_tuple(val, length=1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
|
||||
# helper classes
|
||||
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class Conv_PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = LayerNorm2d(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x):
|
||||
return self.fn(self.norm(x)) + x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=2, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim, 1, 1, 0),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Gated_Conv_FeedForward(nn.Module):
|
||||
def __init__(self, dim, mult=1, bias=False, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = int(dim * mult)
|
||||
|
||||
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
|
||||
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features * 2,
|
||||
hidden_features * 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=hidden_features * 2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.project_in(x)
|
||||
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||
x = F.gelu(x1) * x2
|
||||
x = self.project_out(x)
|
||||
return x
|
||||
|
||||
|
||||
# MBConv
|
||||
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
def __init__(self, dim, shrinkage_rate=0.25):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * shrinkage_rate)
|
||||
|
||||
self.gate = nn.Sequential(
|
||||
Reduce("b c h w -> b c", "mean"),
|
||||
nn.Linear(dim, hidden_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_dim, dim, bias=False),
|
||||
nn.Sigmoid(),
|
||||
Rearrange("b c -> b c 1 1"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gate(x)
|
||||
|
||||
|
||||
class MBConvResidual(nn.Module):
|
||||
def __init__(self, fn, dropout=0.0):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.dropsample = Dropsample(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fn(x)
|
||||
out = self.dropsample(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class Dropsample(nn.Module):
|
||||
def __init__(self, prob=0):
|
||||
super().__init__()
|
||||
self.prob = prob
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
|
||||
if self.prob == 0.0 or (not self.training):
|
||||
return x
|
||||
|
||||
keep_mask = (
|
||||
torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
|
||||
> self.prob
|
||||
)
|
||||
return x * keep_mask / (1 - self.prob)
|
||||
|
||||
|
||||
def MBConv(
|
||||
dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
|
||||
):
|
||||
hidden_dim = int(expansion_rate * dim_out)
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(
|
||||
hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
|
||||
),
|
||||
# nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
# nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
if dim_in == dim_out and not downsample:
|
||||
net = MBConvResidual(net, dropout=dropout)
|
||||
|
||||
return net
|
||||
|
||||
|
||||
# attention related classes
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# relative positional bias
|
||||
if self.with_pe:
|
||||
self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
|
||||
|
||||
pos = torch.arange(window_size)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos))
|
||||
grid = rearrange(grid, "c i j -> (i j) c")
|
||||
rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
|
||||
grid, "j ... -> 1 j ..."
|
||||
)
|
||||
rel_pos += window_size - 1
|
||||
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
batch, height, width, window_height, window_width, _, device, h = (
|
||||
*x.shape,
|
||||
x.device,
|
||||
self.heads,
|
||||
)
|
||||
|
||||
# flatten
|
||||
|
||||
x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
|
||||
|
||||
# project for queries, keys, values
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# add positional bias
|
||||
if self.with_pe:
|
||||
bias = self.rel_pos_bias(self.rel_pos_indices)
|
||||
sim = sim + rearrange(bias, "i j h -> h i j")
|
||||
|
||||
# attention
|
||||
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = rearrange(
|
||||
out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
|
||||
)
|
||||
|
||||
# combine heads out
|
||||
|
||||
out = self.to_out(out)
|
||||
return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
||||
|
||||
|
||||
class Block_Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=32,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
window_size=7,
|
||||
with_pe=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
dim % dim_head
|
||||
) == 0, "dimension should be divisible by dimension per head"
|
||||
|
||||
self.heads = dim // dim_head
|
||||
self.ps = window_size
|
||||
self.scale = dim_head**-0.5
|
||||
self.with_pe = with_pe
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
|
||||
|
||||
self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
# project for queries, keys, values
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
|
||||
h=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# scale
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
# sim
|
||||
|
||||
sim = einsum("b h i d, b h j d -> b h i j", q, k)
|
||||
|
||||
# attention
|
||||
attn = self.attend(sim)
|
||||
|
||||
# aggregate
|
||||
|
||||
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
||||
|
||||
# merge heads
|
||||
out = rearrange(
|
||||
out,
|
||||
"(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
|
||||
x=h // self.ps,
|
||||
y=w // self.ps,
|
||||
head=self.heads,
|
||||
w1=self.ps,
|
||||
w2=self.ps,
|
||||
)
|
||||
|
||||
out = self.to_out(out)
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Channel_Attention_grid(nn.Module):
|
||||
def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
|
||||
super(Channel_Attention_grid, self).__init__()
|
||||
self.heads = heads
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.ps = window_size
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
|
||||
self.qkv_dwconv = nn.Conv2d(
|
||||
dim * 3,
|
||||
dim * 3,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * 3,
|
||||
bias=bias,
|
||||
)
|
||||
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
qkv = self.qkv_dwconv(self.qkv(x))
|
||||
qkv = qkv.chunk(3, dim=1)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(
|
||||
t,
|
||||
"b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
),
|
||||
qkv,
|
||||
)
|
||||
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
out = attn @ v
|
||||
|
||||
out = rearrange(
|
||||
out,
|
||||
"b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
|
||||
h=h // self.ps,
|
||||
w=w // self.ps,
|
||||
ph=self.ps,
|
||||
pw=self.ps,
|
||||
head=self.heads,
|
||||
)
|
||||
|
||||
out = self.project_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class OSA_Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
ffn_bias=True,
|
||||
window_size=8,
|
||||
with_pe=False,
|
||||
dropout=0.0,
|
||||
):
|
||||
super(OSA_Block, self).__init__()
|
||||
|
||||
w = window_size
|
||||
|
||||
self.layer = nn.Sequential(
|
||||
MBConv(
|
||||
channel_num,
|
||||
channel_num,
|
||||
downsample=False,
|
||||
expansion_rate=1,
|
||||
shrinkage_rate=0.25,
|
||||
),
|
||||
Rearrange(
|
||||
"b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # block-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
Rearrange(
|
||||
"b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
|
||||
), # grid-like attention
|
||||
PreNormResidual(
|
||||
channel_num,
|
||||
Attention(
|
||||
dim=channel_num,
|
||||
dim_head=channel_num // 4,
|
||||
dropout=dropout,
|
||||
window_size=window_size,
|
||||
with_pe=with_pe,
|
||||
),
|
||||
),
|
||||
Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
# channel-like attention
|
||||
Conv_PreNormResidual(
|
||||
channel_num,
|
||||
Channel_Attention_grid(
|
||||
dim=channel_num, heads=4, dropout=dropout, window_size=window_size
|
||||
),
|
||||
),
|
||||
Conv_PreNormResidual(
|
||||
channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.layer(x)
|
||||
return out
|
||||
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
60
comfy_extras/chainner_models/architecture/OmniSR/OSAG.py
Normal file
@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OSAG.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:08:49 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .esa import ESA
|
||||
from .OSA import OSA_Block
|
||||
|
||||
|
||||
class OSAG(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channel_num=64,
|
||||
bias=True,
|
||||
block_num=4,
|
||||
ffn_bias=False,
|
||||
window_size=0,
|
||||
pe=False,
|
||||
):
|
||||
super(OSAG, self).__init__()
|
||||
|
||||
# print("window_size: %d" % (window_size))
|
||||
# print("with_pe", pe)
|
||||
# print("ffn_bias: %d" % (ffn_bias))
|
||||
|
||||
# block_script_name = kwargs.get("block_script_name", "OSA")
|
||||
# block_class_name = kwargs.get("block_class_name", "OSA_Block")
|
||||
|
||||
# script_name = "." + block_script_name
|
||||
# package = __import__(script_name, fromlist=True)
|
||||
block_class = OSA_Block # getattr(package, block_class_name)
|
||||
group_list = []
|
||||
for _ in range(block_num):
|
||||
temp_res = block_class(
|
||||
channel_num,
|
||||
bias,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=window_size,
|
||||
with_pe=pe,
|
||||
)
|
||||
group_list.append(temp_res)
|
||||
group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
|
||||
self.residual_layer = nn.Sequential(*group_list)
|
||||
esa_channel = max(channel_num // 4, 16)
|
||||
self.esa = ESA(esa_channel, channel_num)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.residual_layer(x)
|
||||
out = out + x
|
||||
return self.esa(out)
|
||||
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
133
comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py
Normal file
@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: OmniSR.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .OSAG import OSAG
|
||||
from .pixelshuffle import pixelshuffle_block
|
||||
|
||||
|
||||
class OmniSR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
state_dict,
|
||||
**kwargs,
|
||||
):
|
||||
super(OmniSR, self).__init__()
|
||||
self.state = state_dict
|
||||
|
||||
bias = True # Fine to assume this for now
|
||||
block_num = 1 # Fine to assume this for now
|
||||
ffn_bias = True
|
||||
pe = True
|
||||
|
||||
num_feat = state_dict["input.weight"].shape[0] or 64
|
||||
num_in_ch = state_dict["input.weight"].shape[1] or 3
|
||||
num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
|
||||
|
||||
pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
|
||||
up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
|
||||
if up_scale - int(up_scale) > 0:
|
||||
print(
|
||||
"out_nc is probably different than in_nc, scale calculation might be wrong"
|
||||
)
|
||||
up_scale = int(up_scale)
|
||||
res_num = 0
|
||||
for key in state_dict.keys():
|
||||
if "residual_layer" in key:
|
||||
temp_res_num = int(key.split(".")[1])
|
||||
if temp_res_num > res_num:
|
||||
res_num = temp_res_num
|
||||
res_num = res_num + 1 # zero-indexed
|
||||
|
||||
residual_layer = []
|
||||
self.res_num = res_num
|
||||
|
||||
self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
|
||||
self.up_scale = up_scale
|
||||
|
||||
for _ in range(res_num):
|
||||
temp_res = OSAG(
|
||||
channel_num=num_feat,
|
||||
bias=bias,
|
||||
block_num=block_num,
|
||||
ffn_bias=ffn_bias,
|
||||
window_size=self.window_size,
|
||||
pe=pe,
|
||||
)
|
||||
residual_layer.append(temp_res)
|
||||
self.residual_layer = nn.Sequential(*residual_layer)
|
||||
self.input = nn.Conv2d(
|
||||
in_channels=num_in_ch,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.output = nn.Conv2d(
|
||||
in_channels=num_feat,
|
||||
out_channels=num_feat,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
|
||||
|
||||
# self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2d):
|
||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
# m.weight.data.normal_(0, sqrt(2. / n))
|
||||
|
||||
# chaiNNer specific stuff
|
||||
self.model_arch = "OmniSR"
|
||||
self.sub_type = "SR"
|
||||
self.in_nc = num_in_ch
|
||||
self.out_nc = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.scale = up_scale
|
||||
|
||||
self.supports_fp16 = True # TODO: Test this
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = 16
|
||||
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
# import pdb; pdb.set_trace()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
# x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
residual = self.input(x)
|
||||
out = self.residual_layer(residual)
|
||||
|
||||
# origin
|
||||
out = torch.add(self.output(out), residual)
|
||||
out = self.up(out)
|
||||
|
||||
out = out[:, :, : H * self.up_scale, : W * self.up_scale]
|
||||
return out
|
||||
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
294
comfy_extras/chainner_models/architecture/OmniSR/esa.py
Normal file
@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: esa.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:06 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .layernorm import LayerNorm2d
|
||||
|
||||
|
||||
def moment(x, dim=(2, 3), k=2):
|
||||
assert len(x.size()) == 4
|
||||
mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
|
||||
mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
|
||||
return mk
|
||||
|
||||
|
||||
class ESA(nn.Module):
|
||||
"""
|
||||
Modification of Enhanced Spatial Attention (ESA), which is proposed by
|
||||
`Residual Feature Aggregation Network for Image Super-Resolution`
|
||||
Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
|
||||
are deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
|
||||
super(ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
||||
self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
c1 = self.conv2(c1_)
|
||||
v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
||||
c3 = self.conv3(v_max)
|
||||
c3 = F.interpolate(
|
||||
c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
|
||||
)
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(c3 + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.conv1(x)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class LK_ESA_LN(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(LK_ESA_LN, self).__init__()
|
||||
f = esa_channels
|
||||
self.conv1 = conv(n_feats, f, kernel_size=1)
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.norm = LayerNorm2d(n_feats)
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.vec_conv3x1 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(1, 3),
|
||||
padding=(0, 1),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
self.hor_conv1x3 = nn.Conv2d(
|
||||
in_channels=f * kernel_expand,
|
||||
out_channels=f * kernel_expand,
|
||||
kernel_size=(3, 1),
|
||||
padding=(1, 0),
|
||||
groups=2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.conv4 = conv(f, n_feats, kernel_size=1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
c1_ = self.norm(x)
|
||||
c1_ = self.conv1(c1_)
|
||||
|
||||
res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
|
||||
res = self.hor_conv(res) + self.hor_conv1x3(res)
|
||||
|
||||
cf = self.conv_f(c1_)
|
||||
c4 = self.conv4(res + cf)
|
||||
m = self.sigmoid(c4)
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaGuidedFilter, self).__init__()
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=n_feats,
|
||||
out_channels=1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.r = 5
|
||||
|
||||
def box_filter(self, x, r):
|
||||
channel = x.shape[1]
|
||||
kernel_size = 2 * r + 1
|
||||
weight = 1.0 / (kernel_size**2)
|
||||
box_kernel = weight * torch.ones(
|
||||
(channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
|
||||
)
|
||||
output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
|
||||
return output
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
N = self.box_filter(
|
||||
torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
|
||||
)
|
||||
|
||||
# epsilon = self.fc(self.gap(x))
|
||||
# epsilon = torch.pow(epsilon, 2)
|
||||
epsilon = 1e-2
|
||||
|
||||
mean_x = self.box_filter(x, self.r) / N
|
||||
var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
|
||||
|
||||
A = var_x / (var_x + epsilon)
|
||||
b = (1 - A) * mean_x
|
||||
m = A * x + b
|
||||
|
||||
# mean_A = self.box_filter(A, self.r) / N
|
||||
# mean_b = self.box_filter(b, self.r) / N
|
||||
# m = mean_A * x + mean_b
|
||||
return x * m
|
||||
|
||||
|
||||
class AdaConvGuidedFilter(nn.Module):
|
||||
def __init__(
|
||||
self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
|
||||
):
|
||||
super(AdaConvGuidedFilter, self).__init__()
|
||||
f = esa_channels
|
||||
|
||||
self.conv_f = conv(f, f, kernel_size=1)
|
||||
|
||||
kernel_size = 17
|
||||
kernel_expand = kernel_expand
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.vec_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(1, kernel_size),
|
||||
padding=(0, padding),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.hor_conv = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=(kernel_size, 1),
|
||||
padding=(padding, 0),
|
||||
groups=f,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Conv2d(
|
||||
in_channels=f,
|
||||
out_channels=f,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
stride=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.vec_conv(x)
|
||||
y = self.hor_conv(y)
|
||||
|
||||
sigma = torch.pow(y, 2)
|
||||
epsilon = self.fc(self.gap(y))
|
||||
|
||||
weight = sigma / (sigma + epsilon)
|
||||
|
||||
m = weight * x + (1 - weight)
|
||||
|
||||
return x * m
|
||||
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: layernorm.py
|
||||
# Created Date: Tuesday April 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 20th April 2023 9:28:20 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, bias, eps):
|
||||
ctx.eps = eps
|
||||
N, C, H, W = x.size()
|
||||
mu = x.mean(1, keepdim=True)
|
||||
var = (x - mu).pow(2).mean(1, keepdim=True)
|
||||
y = (x - mu) / (var + eps).sqrt()
|
||||
ctx.save_for_backward(y, var, weight)
|
||||
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
eps = ctx.eps
|
||||
|
||||
N, C, H, W = grad_output.size()
|
||||
y, var, weight = ctx.saved_variables
|
||||
g = grad_output * weight.view(1, C, 1, 1)
|
||||
mean_g = g.mean(dim=1, keepdim=True)
|
||||
|
||||
mean_gy = (g * y).mean(dim=1, keepdim=True)
|
||||
gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
|
||||
return (
|
||||
gx,
|
||||
(grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, channels, eps=1e-6):
|
||||
super(LayerNorm2d, self).__init__()
|
||||
self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
|
||||
self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
"""GRN (Global Response Normalization) layer"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: pixelshuffle.py
|
||||
# Created Date: Friday July 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def pixelshuffle_block(
|
||||
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
||||
):
|
||||
"""
|
||||
Upsample features according to `upscale_factor`.
|
||||
"""
|
||||
padding = kernel_size // 2
|
||||
conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels * (upscale_factor**2),
|
||||
kernel_size,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
return nn.Sequential(*[conv, pixel_shuffle])
|
||||
@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
|
||||
self.scale: int = self.get_scale()
|
||||
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
|
||||
|
||||
c2x2 = False
|
||||
if self.state["model.0.weight"].shape[-2] == 2:
|
||||
c2x2 = True
|
||||
self.scale = round(math.sqrt(self.scale / 4))
|
||||
self.model_arch = "ESRGAN-2c2"
|
||||
|
||||
self.supports_fp16 = True
|
||||
self.supports_bfp16 = True
|
||||
self.min_size_restriction = None
|
||||
@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
|
||||
out_nc=self.num_filters,
|
||||
upscale_factor=3,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
else:
|
||||
upsample_blocks = [
|
||||
upsample_block(
|
||||
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
|
||||
in_nc=self.num_filters,
|
||||
out_nc=self.num_filters,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(int(math.log(self.scale, 2)))
|
||||
]
|
||||
@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
B.ShortcutBlock(
|
||||
B.sequential(
|
||||
@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
|
||||
act_type=self.act,
|
||||
mode="CNA",
|
||||
plus=self.plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
for _ in range(self.num_blocks)
|
||||
],
|
||||
@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
|
||||
norm_type=self.norm,
|
||||
act_type=None,
|
||||
mode=self.mode,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
),
|
||||
@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=self.act,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
# hr_conv1
|
||||
B.conv_block(
|
||||
@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
act_type=None,
|
||||
c2x2=c2x2,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -141,6 +141,19 @@ def sequential(*args):
|
||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
||||
|
||||
|
||||
# 2x2x2 Conv Block
|
||||
def conv_block_2c2(
|
||||
in_nc,
|
||||
out_nc,
|
||||
act_type="relu",
|
||||
):
|
||||
return sequential(
|
||||
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
||||
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
||||
act(act_type) if act_type else None,
|
||||
)
|
||||
|
||||
|
||||
def conv_block(
|
||||
in_nc: int,
|
||||
out_nc: int,
|
||||
@ -153,12 +166,17 @@ def conv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type: str | None = "relu",
|
||||
mode: ConvMode = "CNA",
|
||||
c2x2=False,
|
||||
):
|
||||
"""
|
||||
Conv layer with padding, normalization, activation
|
||||
mode: CNA --> Conv -> Norm -> Act
|
||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||
"""
|
||||
|
||||
if c2x2:
|
||||
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
||||
|
||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
||||
@ -285,6 +303,7 @@ class RRDB(nn.Module):
|
||||
_convtype="Conv2D",
|
||||
_spectral_norm=False,
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(
|
||||
@ -298,6 +317,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB2 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -310,6 +330,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB3 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@ -322,6 +343,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
act_type="leakyrelu",
|
||||
mode: ConvMode = "CNA",
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv2 = conv_block(
|
||||
nf + gc,
|
||||
@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv3 = conv_block(
|
||||
nf + 2 * gc,
|
||||
@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv4 = conv_block(
|
||||
nf + 3 * gc,
|
||||
@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
if mode == "CNA":
|
||||
last_act = None
|
||||
@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=last_act,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -499,6 +527,7 @@ def upconv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type="relu",
|
||||
mode="nearest",
|
||||
c2x2=False,
|
||||
):
|
||||
# Up conv
|
||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
||||
@ -512,5 +541,6 @@ def upconv_block(
|
||||
pad_type=pad_type,
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
state_dict = state_dict["params"]
|
||||
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
# SRVGGNet Real-ESRGAN (v2)
|
||||
if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
|
||||
model = RealESRGANv2(state_dict)
|
||||
@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel:
|
||||
# MAT
|
||||
elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys:
|
||||
model = MAT(state_dict)
|
||||
# Omni-SR
|
||||
elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
|
||||
model = OmniSR(state_dict)
|
||||
# Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
|
||||
else:
|
||||
try:
|
||||
|
||||
@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer
|
||||
from .architecture.HAT import HAT
|
||||
from .architecture.LaMa import LaMa
|
||||
from .architecture.MAT import MAT
|
||||
from .architecture.OmniSR.OmniSR import OmniSR
|
||||
from .architecture.RRDB import RRDBNet as ESRGAN
|
||||
from .architecture.SPSR import SPSRNet as SPSR
|
||||
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
||||
@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
||||
from .architecture.Swin2SR import Swin2SR
|
||||
from .architecture.SwinIR import SwinIR
|
||||
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT)
|
||||
PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR)
|
||||
PyTorchSRModel = Union[
|
||||
RealESRGANv2,
|
||||
SPSR,
|
||||
@ -22,6 +23,7 @@ PyTorchSRModel = Union[
|
||||
SwinIR,
|
||||
Swin2SR,
|
||||
HAT,
|
||||
OmniSR,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength):
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
"softsign": torch.nn.Softsign,
|
||||
}
|
||||
|
||||
if activation_func not in valid_activation:
|
||||
|
||||
@ -72,7 +72,7 @@ class MaskToImage:
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
|
||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
|
||||
class ImageToMask:
|
||||
|
||||
@ -59,6 +59,12 @@ class Blend:
|
||||
def g(self, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def gaussian_kernel(kernel_size: int, sigma: float):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
|
||||
class Blur:
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -88,12 +94,6 @@ class Blur:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def gaussian_kernel(self, kernel_size: int, sigma: float):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
|
||||
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
||||
if blur_radius == 0:
|
||||
return (image,)
|
||||
@ -101,10 +101,11 @@ class Blur:
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = blur_radius * 2 + 1
|
||||
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
|
||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
|
||||
blurred = blurred.permute(0, 2, 3, 1)
|
||||
|
||||
return (blurred,)
|
||||
@ -167,9 +168,15 @@ class Sharpen:
|
||||
"max": 31,
|
||||
"step": 1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"sigma": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.1,
|
||||
"max": 10.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
"alpha": ("FLOAT", {
|
||||
"default": 1.0,
|
||||
"min": 0.0,
|
||||
"max": 5.0,
|
||||
"step": 0.1
|
||||
}),
|
||||
@ -181,21 +188,21 @@ class Sharpen:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
|
||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
|
||||
if sharpen_radius == 0:
|
||||
return (image,)
|
||||
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
|
||||
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel_size**2
|
||||
kernel *= alpha
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
|
||||
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||
sharpened = sharpened.permute(0, 2, 3, 1)
|
||||
|
||||
result = torch.clamp(sharpened, 0, 1)
|
||||
|
||||
108
comfy_extras/nodes_rebatch.py
Normal file
108
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,108 @@
|
||||
import torch
|
||||
|
||||
class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
@staticmethod
|
||||
def get_batch(latents, list_ind, offset):
|
||||
'''prepare a batch out of the list of latents'''
|
||||
samples = latents[list_ind]['samples']
|
||||
shape = samples.shape
|
||||
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||
if mask.shape[0] < samples.shape[0]:
|
||||
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||
if 'batch_index' in latents[list_ind]:
|
||||
batch_inds = latents[list_ind]['batch_index']
|
||||
else:
|
||||
batch_inds = [x+offset for x in range(shape[0])]
|
||||
return samples, mask, batch_inds
|
||||
|
||||
@staticmethod
|
||||
def get_slices(indexable, num, batch_size):
|
||||
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||
slices = []
|
||||
for i in range(num):
|
||||
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||
if num * batch_size < len(indexable):
|
||||
return slices, indexable[num * batch_size:]
|
||||
else:
|
||||
return slices, None
|
||||
|
||||
@staticmethod
|
||||
def slice_batch(batch, num, batch_size):
|
||||
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||
return list(zip(*result))
|
||||
|
||||
@staticmethod
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
current_batch = (None, None, None)
|
||||
processed = 0
|
||||
|
||||
for i in range(len(latents)):
|
||||
# fetch new entry of list
|
||||
#samples, masks, indices = self.get_batch(latents, i)
|
||||
next_batch = self.get_batch(latents, i, processed)
|
||||
processed += len(next_batch[2])
|
||||
# set to current if current is None
|
||||
if current_batch[0] is None:
|
||||
current_batch = next_batch
|
||||
# add previous to list if dimensions do not match
|
||||
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
current_batch = next_batch
|
||||
# cat if everything checks out
|
||||
else:
|
||||
current_batch = self.cat_batch(current_batch, next_batch)
|
||||
|
||||
# add to list if dimensions gone above target batch size
|
||||
if current_batch[0].shape[0] > batch_size:
|
||||
num = current_batch[0].shape[0] // batch_size
|
||||
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||
|
||||
for i in range(num):
|
||||
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||
|
||||
current_batch = remainder
|
||||
|
||||
#add remainder
|
||||
if current_batch[0] is not None:
|
||||
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||
|
||||
#get rid of empty masks
|
||||
for s in output_list:
|
||||
if s['noise_mask'].mean() == 1.0:
|
||||
del s['noise_mask']
|
||||
|
||||
return (output_list,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RebatchLatents": LatentRebatch,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
||||
@ -17,7 +17,7 @@ class UpscaleModelLoader:
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
out = model_loading.load_state_dict(sd).eval()
|
||||
return (out, )
|
||||
|
||||
|
||||
188
execution.py
188
execution.py
@ -6,6 +6,7 @@ import threading
|
||||
import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
@ -26,21 +27,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = input_data
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = prompt
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = extra_data['extra_pnginfo']
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = unique_id
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
results = []
|
||||
if intput_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
@ -55,21 +117,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id]:
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
@ -105,7 +166,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
@ -144,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.outputs = {}
|
||||
self.outputs_ui = {}
|
||||
self.old_prompt = {}
|
||||
self.server = server
|
||||
|
||||
def execute(self, prompt, extra_data={}):
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
@ -155,6 +218,10 @@ class PromptExecutor:
|
||||
else:
|
||||
self.server.client_id = None
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
@ -169,32 +236,34 @@ class PromptExecutor:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in prompt:
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
to_execute += [(0, x)]
|
||||
for x in list(execute_outputs):
|
||||
to_execute += [(0, x)]
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
x = to_execute.pop(0)[-1]
|
||||
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
if hasattr(class_, 'OUTPUT_NODE'):
|
||||
if class_.OUTPUT_NODE == True:
|
||||
valid = False
|
||||
try:
|
||||
m = validate_inputs(prompt, x)
|
||||
valid = m[0]
|
||||
except:
|
||||
valid = False
|
||||
if valid:
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||
print("Processing interrupted")
|
||||
else:
|
||||
message = str(traceback.format_exc())
|
||||
print(message)
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
@ -210,14 +279,18 @@ class PromptExecutor:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
||||
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
|
||||
def validate_inputs(prompt, item):
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
@ -226,20 +299,21 @@ def validate_inputs(prompt, item):
|
||||
required_inputs = class_inputs['required']
|
||||
for x in required_inputs:
|
||||
if x not in inputs:
|
||||
return (False, "Required input is missing. {}, {}".format(class_type, x))
|
||||
return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id)
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
return (False, "Bad Input. {}, {}".format(class_type, x))
|
||||
return (False, "Bad Input. {}, {}".format(class_type, x), unique_id)
|
||||
o_id = val[0]
|
||||
o_class_type = prompt[o_id]['class_type']
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
if r[val[1]] != type_input:
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||
r = validate_inputs(prompt, o_id)
|
||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id)
|
||||
r = validate_inputs(prompt, o_id, validated)
|
||||
if r[0] == False:
|
||||
validated[o_id] = r
|
||||
return r
|
||||
else:
|
||||
if type_input == "INT":
|
||||
@ -254,20 +328,25 @@ def validate_inputs(prompt, item):
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
return (False, "Value smaller than min. {}, {}".format(class_type, x))
|
||||
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id)
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
||||
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id)
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
if ret != True:
|
||||
return (False, "{}, {}".format(class_type, ret))
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
return (False, "{}, {}".format(class_type, r), unique_id)
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
return (True, "")
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id)
|
||||
|
||||
ret = (True, "", unique_id)
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
def validate_prompt(prompt):
|
||||
outputs = set()
|
||||
@ -277,34 +356,42 @@ def validate_prompt(prompt):
|
||||
outputs.add(x)
|
||||
|
||||
if len(outputs) == 0:
|
||||
return (False, "Prompt has no outputs")
|
||||
return (False, "Prompt has no outputs", [], [])
|
||||
|
||||
good_outputs = set()
|
||||
errors = []
|
||||
node_errors = {}
|
||||
validated = {}
|
||||
for o in outputs:
|
||||
valid = False
|
||||
reason = ""
|
||||
try:
|
||||
m = validate_inputs(prompt, o)
|
||||
m = validate_inputs(prompt, o, validated)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
node_id = m[2]
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
node_id = None
|
||||
|
||||
if valid == True:
|
||||
good_outputs.add(x)
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||
print("output will be ignored")
|
||||
errors += [(o, reason)]
|
||||
if node_id is not None:
|
||||
if node_id not in node_errors:
|
||||
node_errors[node_id] = {"message": reason, "dependent_outputs": []}
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors)
|
||||
|
||||
return (True, "")
|
||||
return (True, "", list(good_outputs), node_errors)
|
||||
|
||||
|
||||
class PromptQueue:
|
||||
@ -340,8 +427,7 @@ class PromptQueue:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
if "ui" in outputs[o]:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
||||
@ -147,4 +147,37 @@ def get_filename_list(folder_name):
|
||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||
return sorted(list(output_list))
|
||||
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input, image_width, image_height):
|
||||
input = input.replace("%width%", str(image_width))
|
||||
input = input.replace("%height%", str(image_height))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||
|
||||
17
main.py
17
main.py
@ -33,8 +33,8 @@ def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
e.execute(item[-2], item[-1])
|
||||
q.task_done(item_id, e.outputs)
|
||||
e.execute(item[2], item[1], item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
@ -91,23 +91,16 @@ if __name__ == "__main__":
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
|
||||
|
||||
address = args.listen
|
||||
|
||||
dont_print = args.dont_print_server
|
||||
|
||||
|
||||
if args.output_directory:
|
||||
output_dir = os.path.abspath(args.output_directory)
|
||||
print(f"Setting output directory to: {output_dir}")
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
|
||||
port = args.port
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
exit(0)
|
||||
|
||||
call_on_start = None
|
||||
if args.windows_standalone_build:
|
||||
if args.auto_launch:
|
||||
def startup_server(address, port):
|
||||
import webbrowser
|
||||
webbrowser.open("http://{}:{}".format(address, port))
|
||||
@ -115,10 +108,10 @@ if __name__ == "__main__":
|
||||
|
||||
if os.name == "nt":
|
||||
try:
|
||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
|
||||
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
|
||||
cleanup_temp()
|
||||
|
||||
261
nodes.py
261
nodes.py
@ -6,10 +6,12 @@ import json
|
||||
import hashlib
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
@ -28,6 +30,7 @@ import importlib
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
@ -105,15 +108,13 @@ class ConditioningSetArea:
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
||||
def append(self, conditioning, width, height, x, y, strength):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['set_area_to_bounds'] = False
|
||||
n[1]['min_sigma'] = min_sigma
|
||||
n[1]['max_sigma'] = max_sigma
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
@ -147,9 +148,6 @@ class ConditioningSetMask:
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -162,9 +160,6 @@ class VAEDecode:
|
||||
return (vae.decode(samples["samples"]), )
|
||||
|
||||
class VAEDecodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
@ -177,9 +172,6 @@ class VAEDecodeTiled:
|
||||
return (vae.decode_tiled(samples["samples"]), )
|
||||
|
||||
class VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -204,9 +196,6 @@ class VAEEncode:
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeTiled:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||
@ -221,9 +210,6 @@ class VAEEncodeTiled:
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
||||
@ -262,6 +248,81 @@ class VAEEncodeForInpaint:
|
||||
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||
|
||||
|
||||
class SaveLatent:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ),
|
||||
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
|
||||
# support save metadata for latent sharing
|
||||
prompt_info = ""
|
||||
if prompt is not None:
|
||||
prompt_info = json.dumps(prompt)
|
||||
|
||||
metadata = {"prompt": prompt_info}
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
file = f"{filename}_{counter:05}_.latent"
|
||||
file = os.path.join(full_output_folder, file)
|
||||
|
||||
output = {}
|
||||
output["latent_tensor"] = samples["samples"]
|
||||
|
||||
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
class LoadLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||
return {"required": {"latent": [sorted(files), ]}, }
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
RETURN_TYPES = ("LATENT", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, latent):
|
||||
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||
samples = {"samples": latent["latent_tensor"].float()}
|
||||
return (samples, )
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, latent):
|
||||
image_path = folder_paths.get_annotated_filepath(latent)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, latent):
|
||||
if not folder_paths.exists_annotated_filepath(latent):
|
||||
return "Invalid latent file: {}".format(latent)
|
||||
return True
|
||||
|
||||
|
||||
class CheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -298,7 +359,10 @@ class DiffusersLoader:
|
||||
paths = []
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths += next(os.walk(search_path))[1]
|
||||
for root, subdir, files in os.walk(search_path, followlinks=True):
|
||||
if "model_index.json" in files:
|
||||
paths.append(os.path.relpath(root, start=search_path))
|
||||
|
||||
return {"required": {"model_path": (paths,), }}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -308,9 +372,9 @@ class DiffusersLoader:
|
||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||
if os.path.exists(search_path):
|
||||
paths = next(os.walk(search_path))[1]
|
||||
if model_path in paths:
|
||||
model_path = os.path.join(search_path, model_path)
|
||||
path = os.path.join(search_path, model_path)
|
||||
if os.path.exists(path):
|
||||
model_path = path
|
||||
break
|
||||
|
||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
@ -445,7 +509,6 @@ class ControlNetApply:
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
print(control_hint.shape)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
||||
@ -632,22 +695,61 @@ class LatentFromBatch:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "rotate"
|
||||
FUNCTION = "frombatch"
|
||||
|
||||
CATEGORY = "latent"
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def rotate(self, samples, batch_index):
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
s["samples"] = s_in[batch_index:batch_index + 1].clone()
|
||||
s["batch_index"] = batch_index
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] == 1:
|
||||
s["noise_mask"] = masks.clone()
|
||||
else:
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||
if "batch_index" not in s:
|
||||
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
if masks.shape[0] < s_in.shape[0]:
|
||||
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||
if "batch_index" in s:
|
||||
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||
return (s,)
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
|
||||
@classmethod
|
||||
@ -666,6 +768,25 @@ class LatentUpscale:
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
return (s,)
|
||||
|
||||
class LatentUpscaleBy:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
|
||||
"scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def upscale(self, samples, upscale_method, scale_by):
|
||||
s = samples.copy()
|
||||
width = round(samples["samples"].shape[3] * scale_by)
|
||||
height = round(samples["samples"].shape[2] * scale_by)
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
|
||||
return (s,)
|
||||
|
||||
class LatentRotate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -798,7 +919,7 @@ class SetLatentNoiseMask:
|
||||
|
||||
def set_mask(self, samples, mask):
|
||||
s = samples.copy()
|
||||
s["noise_mask"] = mask
|
||||
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||
return (s,)
|
||||
|
||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
@ -808,8 +929,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
@ -904,39 +1025,7 @@ class SaveImage:
|
||||
CATEGORY = "image"
|
||||
|
||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
def map_filename(filename):
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return (digits, prefix)
|
||||
|
||||
def compute_vars(input):
|
||||
input = input.replace("%width%", str(images[0].shape[1]))
|
||||
input = input.replace("%height%", str(images[0].shape[0]))
|
||||
return input
|
||||
|
||||
filename_prefix = compute_vars(filename_prefix)
|
||||
|
||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||
|
||||
full_output_folder = os.path.join(self.output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
counter = 1
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
@ -975,8 +1064,9 @@ class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), )},
|
||||
{"image": (sorted(files), )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -986,6 +1076,7 @@ class LoadImage:
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
@ -1016,9 +1107,10 @@ class LoadImageMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), ),
|
||||
"channel": (s._color_channels, ),}
|
||||
{"image": (sorted(files), ),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
@ -1028,6 +1120,7 @@ class LoadImageMask:
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
@ -1170,7 +1263,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"VAELoader": VAELoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"LatentUpscaleBy": LatentUpscaleBy,
|
||||
"LatentFromBatch": LatentFromBatch,
|
||||
"RepeatLatentBatch": RepeatLatentBatch,
|
||||
"SaveImage": SaveImage,
|
||||
"PreviewImage": PreviewImage,
|
||||
"LoadImage": LoadImage,
|
||||
@ -1207,6 +1302,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
"CheckpointLoader": CheckpointLoader,
|
||||
"DiffusersLoader": DiffusersLoader,
|
||||
|
||||
"LoadLatent": LoadLatent,
|
||||
"SaveLatent": SaveLatent
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1244,7 +1342,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LatentCrop": "Crop Latent",
|
||||
"EmptyLatentImage": "Empty Latent Image",
|
||||
"LatentUpscale": "Upscale Latent",
|
||||
"LatentUpscaleBy": "Upscale Latent By",
|
||||
"LatentComposite": "Latent Composite",
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
# Image
|
||||
"SaveImage": "Save Image",
|
||||
"PreviewImage": "Preview Image",
|
||||
@ -1276,14 +1377,18 @@ def load_custom_node(module_path):
|
||||
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
else:
|
||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||
return False
|
||||
|
||||
def load_custom_nodes():
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
if "__pycache__" in possible_modules:
|
||||
@ -1292,11 +1397,25 @@ def load_custom_nodes():
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
load_custom_node(module_path)
|
||||
if module_path.endswith(".disabled"): continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path)
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
print("\nImport times for custom nodes:")
|
||||
for n in sorted(node_import_times):
|
||||
if n[2]:
|
||||
import_message = ""
|
||||
else:
|
||||
import_message = " (IMPORT FAILED)"
|
||||
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||
print()
|
||||
|
||||
def init_custom_nodes():
|
||||
load_custom_nodes()
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||
load_custom_nodes()
|
||||
|
||||
@ -175,6 +175,8 @@
|
||||
"import threading\n",
|
||||
"import time\n",
|
||||
"import socket\n",
|
||||
"import urllib.request\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
@ -183,7 +185,9 @@
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
|
||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||
"\n",
|
||||
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end='')\n",
|
||||
|
||||
203
server.py
203
server.py
@ -7,6 +7,9 @@ import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
@ -78,7 +81,7 @@ class PromptServer():
|
||||
# Reusing existing session, remove old
|
||||
self.sockets.pop(sid, None)
|
||||
else:
|
||||
sid = uuid.uuid4().hex
|
||||
sid = uuid.uuid4().hex
|
||||
|
||||
self.sockets[sid] = ws
|
||||
|
||||
@ -110,49 +113,96 @@ class PromptServer():
|
||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
def get_dir_by_type(dir_type):
|
||||
if dir_type is None:
|
||||
dir_type = "input"
|
||||
|
||||
if dir_type == "input":
|
||||
type_dir = folder_paths.get_input_directory()
|
||||
elif dir_type == "temp":
|
||||
type_dir = folder_paths.get_temp_directory()
|
||||
elif dir_type == "output":
|
||||
type_dir = folder_paths.get_output_directory()
|
||||
|
||||
return type_dir, dir_type
|
||||
|
||||
def image_upload(post, image_save_function=None):
|
||||
image = post.get("image")
|
||||
overwrite = post.get("overwrite")
|
||||
|
||||
if post.get("type") is None:
|
||||
upload_dir = folder_paths.get_input_directory()
|
||||
elif post.get("type") == "input":
|
||||
upload_dir = folder_paths.get_input_directory()
|
||||
elif post.get("type") == "temp":
|
||||
upload_dir = folder_paths.get_temp_directory()
|
||||
elif post.get("type") == "output":
|
||||
upload_dir = folder_paths.get_output_directory()
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
image_upload_type = post.get("type")
|
||||
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
|
||||
|
||||
if image and image.file:
|
||||
filename = image.filename
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
subfolder = post.get("subfolder", "")
|
||||
full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
|
||||
|
||||
if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir:
|
||||
return web.Response(status=400)
|
||||
|
||||
if not os.path.exists(full_output_folder):
|
||||
os.makedirs(full_output_folder)
|
||||
|
||||
split = os.path.splitext(filename)
|
||||
i = 1
|
||||
while os.path.exists(os.path.join(upload_dir, filename)):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
i += 1
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
|
||||
filepath = os.path.join(upload_dir, filename)
|
||||
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
|
||||
pass
|
||||
else:
|
||||
i = 1
|
||||
while os.path.exists(filepath):
|
||||
filename = f"{split[0]} ({i}){split[1]}"
|
||||
filepath = os.path.join(full_output_folder, filename)
|
||||
i += 1
|
||||
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename})
|
||||
if image_save_function is not None:
|
||||
image_save_function(image, post, filepath)
|
||||
else:
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@routes.post("/upload/image")
|
||||
async def upload_image(request):
|
||||
post = await request.post()
|
||||
return image_upload(post)
|
||||
|
||||
@routes.post("/upload/mask")
|
||||
async def upload_mask(request):
|
||||
post = await request.post()
|
||||
|
||||
def image_save_function(image, post, filepath):
|
||||
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
|
||||
mask_pil = Image.open(image.file).convert('RGBA')
|
||||
|
||||
# alpha copy
|
||||
new_alpha = mask_pil.getchannel('A')
|
||||
original_pil.putalpha(new_alpha)
|
||||
original_pil.save(filepath, compress_level=4)
|
||||
|
||||
return image_upload(post, image_save_function)
|
||||
|
||||
@routes.get("/view")
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename,output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
@ -162,35 +212,88 @@ class PromptServer():
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
else:
|
||||
channel = request.rel_url.query["channel"]
|
||||
|
||||
if channel == 'rgb':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
r, g, b, a = img.split()
|
||||
new_img = Image.merge('RGB', (r, g, b))
|
||||
else:
|
||||
new_img = img.convert("RGB")
|
||||
|
||||
buffer = BytesIO()
|
||||
new_img.save(buffer, format='PNG')
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
if img.mode == "RGBA":
|
||||
_, _, _, a = img.split()
|
||||
else:
|
||||
a = Image.new('L', img.size, 255)
|
||||
|
||||
# alpha img
|
||||
alpha_img = Image.new('RGBA', img.size)
|
||||
alpha_img.putalpha(a)
|
||||
alpha_buffer = BytesIO()
|
||||
alpha_img.save(alpha_buffer, format='PNG')
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/prompt")
|
||||
async def get_prompt(request):
|
||||
return web.json_response(self.get_queue_info())
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
info['output_node'] = True
|
||||
else:
|
||||
info['output_node'] = False
|
||||
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = x
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
||||
info['description'] = ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'CATEGORY'):
|
||||
info['category'] = obj_class.CATEGORY
|
||||
out[x] = info
|
||||
out[x] = node_info(x)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/object_info/{node_class}")
|
||||
async def get_object_info_node(request):
|
||||
node_class = request.match_info.get("node_class", None)
|
||||
out = {}
|
||||
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/history")
|
||||
@ -232,14 +335,16 @@ class PromptServer():
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
return web.json_response({"prompt_id": prompt_id})
|
||||
else:
|
||||
resp_code = 400
|
||||
out_string = valid[1]
|
||||
print("invalid prompt:", valid[1])
|
||||
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
|
||||
else:
|
||||
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
|
||||
|
||||
return web.Response(body=out_string, status=resp_code)
|
||||
|
||||
@routes.post("/queue")
|
||||
async def post_queue(request):
|
||||
json_data = await request.json()
|
||||
@ -249,9 +354,9 @@ class PromptServer():
|
||||
if "delete" in json_data:
|
||||
to_delete = json_data['delete']
|
||||
for id_to_delete in to_delete:
|
||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
||||
delete_func = lambda a: a[1] == id_to_delete
|
||||
self.prompt_queue.delete_queue_item(delete_func)
|
||||
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/interrupt")
|
||||
@ -275,7 +380,7 @@ class PromptServer():
|
||||
def add_routes(self):
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root),
|
||||
web.static('/', self.web_root, follow_symlinks=True),
|
||||
])
|
||||
|
||||
def get_queue_info(self):
|
||||
|
||||
166
web/extensions/core/clipspace.js
Normal file
166
web/extensions/core/clipspace.js
Normal file
@ -0,0 +1,166 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
|
||||
export class ClipspaceDialog extends ComfyDialog {
|
||||
static items = [];
|
||||
static instance = null;
|
||||
|
||||
static registerButton(name, contextPredicate, callback) {
|
||||
const item =
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: name,
|
||||
contextPredicate: contextPredicate,
|
||||
onclick: callback
|
||||
})
|
||||
|
||||
ClipspaceDialog.items.push(item);
|
||||
}
|
||||
|
||||
static invalidatePreview() {
|
||||
if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
if(img_preview) {
|
||||
img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
img_preview.style.maxHeight = "100%";
|
||||
img_preview.style.maxWidth = "100%";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static invalidate() {
|
||||
if(ClipspaceDialog.instance) {
|
||||
const self = ClipspaceDialog.instance;
|
||||
// allow reconstruct controls when copying from non-image to image content.
|
||||
const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]);
|
||||
|
||||
if(self.element) {
|
||||
// update
|
||||
self.element.removeChild(self.element.firstChild);
|
||||
self.element.appendChild(children);
|
||||
}
|
||||
else {
|
||||
// new
|
||||
self.element = $el("div.comfy-modal", { parent: document.body }, [children,]);
|
||||
}
|
||||
|
||||
if(self.element.children[0].children.length <= 1) {
|
||||
self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."]));
|
||||
}
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
}
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
|
||||
createButtons(self) {
|
||||
const buttons = [];
|
||||
|
||||
for(let idx in ClipspaceDialog.items) {
|
||||
const item = ClipspaceDialog.items[idx];
|
||||
if(!item.contextPredicate || item.contextPredicate())
|
||||
buttons.push(ClipspaceDialog.items[idx]);
|
||||
}
|
||||
|
||||
buttons.push(
|
||||
$el("button", {
|
||||
type: "button",
|
||||
textContent: "Close",
|
||||
onclick: () => { this.close(); }
|
||||
})
|
||||
);
|
||||
|
||||
return buttons;
|
||||
}
|
||||
|
||||
createImgSettings() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
const combo_items = [];
|
||||
const imgs = ComfyApp.clipspace.imgs;
|
||||
|
||||
for(let i=0; i < imgs.length; i++) {
|
||||
combo_items.push($el("option", {value:i}, [`${i}`]));
|
||||
}
|
||||
|
||||
const combo1 = $el("select",
|
||||
{id:"clipspace_img_selector", onchange:(event) => {
|
||||
ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex;
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
} }, combo_items);
|
||||
|
||||
const row1 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]),
|
||||
$el("td", {}, [combo1])
|
||||
]);
|
||||
|
||||
|
||||
const combo2 = $el("select",
|
||||
{id:"clipspace_img_paste_mode", onchange:(event) => {
|
||||
ComfyApp.clipspace['img_paste_mode'] = event.target.value;
|
||||
} },
|
||||
[
|
||||
$el("option", {value:'selected'}, 'selected'),
|
||||
$el("option", {value:'all'}, 'all')
|
||||
]);
|
||||
combo2.value = ComfyApp.clipspace['img_paste_mode'];
|
||||
|
||||
const row2 =
|
||||
$el("tr", {},
|
||||
[
|
||||
$el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]),
|
||||
$el("td", {}, [combo2])
|
||||
]);
|
||||
|
||||
const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'},
|
||||
[ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]);
|
||||
|
||||
const row3 =
|
||||
$el("tr", {}, [td]);
|
||||
|
||||
return $el("table", {}, [row1, row2, row3]);
|
||||
}
|
||||
else {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
createImgPreview() {
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
return $el("img",{id:"clipspace_preview", ondragstart:() => false});
|
||||
}
|
||||
else
|
||||
return [];
|
||||
}
|
||||
|
||||
show() {
|
||||
const img_preview = document.getElementById("clipspace_preview");
|
||||
ClipspaceDialog.invalidate();
|
||||
|
||||
this.element.style.display = "block";
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.Clipspace",
|
||||
init(app) {
|
||||
app.openClipspace =
|
||||
function () {
|
||||
if(!ClipspaceDialog.instance) {
|
||||
ClipspaceDialog.instance = new ClipspaceDialog(app);
|
||||
ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace) {
|
||||
ClipspaceDialog.instance.show();
|
||||
}
|
||||
else
|
||||
app.ui.dialog.show("Clipspace is Empty!");
|
||||
};
|
||||
}
|
||||
});
|
||||
@ -174,7 +174,7 @@ const els = {}
|
||||
// const ctxMenu = LiteGraph.ContextMenu;
|
||||
app.registerExtension({
|
||||
name: id,
|
||||
init() {
|
||||
addCustomNodeDefs(node_defs) {
|
||||
const sortObjectKeys = (unordered) => {
|
||||
return Object.keys(unordered).sort().reduce((obj, key) => {
|
||||
obj[key] = unordered[key];
|
||||
@ -182,10 +182,10 @@ app.registerExtension({
|
||||
}, {});
|
||||
};
|
||||
|
||||
const getSlotTypes = async () => {
|
||||
function getSlotTypes() {
|
||||
var types = [];
|
||||
|
||||
const defs = await api.getNodeDefs();
|
||||
const defs = node_defs;
|
||||
for (const nodeId in defs) {
|
||||
const nodeData = defs[nodeId];
|
||||
|
||||
@ -212,8 +212,8 @@ app.registerExtension({
|
||||
return types;
|
||||
};
|
||||
|
||||
const completeColorPalette = async (colorPalette) => {
|
||||
var types = await getSlotTypes();
|
||||
function completeColorPalette(colorPalette) {
|
||||
var types = getSlotTypes();
|
||||
|
||||
for (const type of types) {
|
||||
if (!colorPalette.colors.node_slot[type]) {
|
||||
|
||||
648
web/extensions/core/maskeditor.js
Normal file
648
web/extensions/core/maskeditor.js
Normal file
@ -0,0 +1,648 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import { ComfyApp } from "/scripts/app.js";
|
||||
import { ClipspaceDialog } from "/extensions/core/clipspace.js";
|
||||
|
||||
// Helper function to convert a data URL to a Blob object
|
||||
function dataURLToBlob(dataURL) {
|
||||
const parts = dataURL.split(';base64,');
|
||||
const contentType = parts[0].split(':')[1];
|
||||
const byteString = atob(parts[1]);
|
||||
const arrayBuffer = new ArrayBuffer(byteString.length);
|
||||
const uint8Array = new Uint8Array(arrayBuffer);
|
||||
for (let i = 0; i < byteString.length; i++) {
|
||||
uint8Array[i] = byteString.charCodeAt(i);
|
||||
}
|
||||
return new Blob([arrayBuffer], { type: contentType });
|
||||
}
|
||||
|
||||
function loadedImageToBlob(image) {
|
||||
const canvas = document.createElement('canvas');
|
||||
|
||||
canvas.width = image.width;
|
||||
canvas.height = image.height;
|
||||
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
ctx.drawImage(image, 0, 0);
|
||||
|
||||
const dataURL = canvas.toDataURL('image/png', 1);
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
return blob;
|
||||
}
|
||||
|
||||
async function uploadMask(filepath, formData) {
|
||||
await fetch('/upload/mask', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
}).then(response => {}).catch(error => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
|
||||
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString();
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;
|
||||
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
|
||||
function prepareRGB(image, backupCanvas, backupCtx) {
|
||||
// paste mask data into alpha channel
|
||||
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
}
|
||||
|
||||
class MaskEditorDialog extends ComfyDialog {
|
||||
static instance = null;
|
||||
|
||||
static getInstance() {
|
||||
if(!MaskEditorDialog.instance) {
|
||||
MaskEditorDialog.instance = new MaskEditorDialog(app);
|
||||
}
|
||||
|
||||
return MaskEditorDialog.instance;
|
||||
}
|
||||
|
||||
is_layout_created = false;
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.element = $el("div.comfy-modal", { parent: document.body },
|
||||
[ $el("div.comfy-modal-content",
|
||||
[...this.createButtons()]),
|
||||
]);
|
||||
}
|
||||
|
||||
createButtons() {
|
||||
return [];
|
||||
}
|
||||
|
||||
createButton(name, callback) {
|
||||
var button = document.createElement("button");
|
||||
button.innerText = name;
|
||||
button.addEventListener("click", callback);
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "left";
|
||||
button.style.marginRight = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createRightButton(name, callback) {
|
||||
var button = this.createButton(name, callback);
|
||||
button.style.cssFloat = "right";
|
||||
button.style.marginLeft = "4px";
|
||||
return button;
|
||||
}
|
||||
|
||||
createLeftSlider(self, name, callback) {
|
||||
const divElement = document.createElement('div');
|
||||
divElement.id = "maskeditor-slider";
|
||||
divElement.style.cssFloat = "left";
|
||||
divElement.style.fontFamily = "sans-serif";
|
||||
divElement.style.marginRight = "4px";
|
||||
divElement.style.color = "var(--input-text)";
|
||||
divElement.style.backgroundColor = "var(--comfy-input-bg)";
|
||||
divElement.style.borderRadius = "8px";
|
||||
divElement.style.borderColor = "var(--border-color)";
|
||||
divElement.style.borderStyle = "solid";
|
||||
divElement.style.fontSize = "15px";
|
||||
divElement.style.height = "21px";
|
||||
divElement.style.padding = "1px 6px";
|
||||
divElement.style.display = "flex";
|
||||
divElement.style.position = "relative";
|
||||
divElement.style.top = "2px";
|
||||
self.brush_slider_input = document.createElement('input');
|
||||
self.brush_slider_input.setAttribute('type', 'range');
|
||||
self.brush_slider_input.setAttribute('min', '1');
|
||||
self.brush_slider_input.setAttribute('max', '100');
|
||||
self.brush_slider_input.setAttribute('value', '10');
|
||||
const labelElement = document.createElement("label");
|
||||
labelElement.textContent = name;
|
||||
|
||||
divElement.appendChild(labelElement);
|
||||
divElement.appendChild(self.brush_slider_input);
|
||||
|
||||
self.brush_slider_input.addEventListener("change", callback);
|
||||
|
||||
return divElement;
|
||||
}
|
||||
|
||||
setlayout(imgCanvas, maskCanvas) {
|
||||
const self = this;
|
||||
|
||||
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
|
||||
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
|
||||
var placeholder = document.createElement("div");
|
||||
placeholder.style.position = "relative";
|
||||
placeholder.style.height = "50px";
|
||||
|
||||
var bottom_panel = document.createElement("div");
|
||||
bottom_panel.style.position = "absolute";
|
||||
bottom_panel.style.bottom = "0px";
|
||||
bottom_panel.style.left = "20px";
|
||||
bottom_panel.style.right = "20px";
|
||||
bottom_panel.style.height = "50px";
|
||||
|
||||
var brush = document.createElement("div");
|
||||
brush.id = "brush";
|
||||
brush.style.backgroundColor = "transparent";
|
||||
brush.style.outline = "1px dashed black";
|
||||
brush.style.boxShadow = "0 0 0 1px white";
|
||||
brush.style.borderRadius = "50%";
|
||||
brush.style.MozBorderRadius = "50%";
|
||||
brush.style.WebkitBorderRadius = "50%";
|
||||
brush.style.position = "absolute";
|
||||
brush.style.zIndex = 8889;
|
||||
brush.style.pointerEvents = "none";
|
||||
this.brush = brush;
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
document.body.appendChild(brush);
|
||||
|
||||
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
self.brush_size = event.target.value;
|
||||
self.updateBrushPreview(self, null, null);
|
||||
});
|
||||
var clearButton = this.createLeftButton("Clear",
|
||||
() => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
|
||||
});
|
||||
var cancelButton = this.createRightButton("Cancel", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.close();
|
||||
});
|
||||
|
||||
this.saveButton = this.createRightButton("Save", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||
self.save();
|
||||
});
|
||||
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
|
||||
bottom_panel.appendChild(clearButton);
|
||||
bottom_panel.appendChild(this.saveButton);
|
||||
bottom_panel.appendChild(cancelButton);
|
||||
bottom_panel.appendChild(brush_size_slider);
|
||||
|
||||
imgCanvas.style.position = "relative";
|
||||
imgCanvas.style.top = "200";
|
||||
imgCanvas.style.left = "0";
|
||||
|
||||
maskCanvas.style.position = "absolute";
|
||||
}
|
||||
|
||||
show() {
|
||||
if(!this.is_layout_created) {
|
||||
// layout
|
||||
const imgCanvas = document.createElement('canvas');
|
||||
const maskCanvas = document.createElement('canvas');
|
||||
const backupCanvas = document.createElement('canvas');
|
||||
|
||||
imgCanvas.id = "imageCanvas";
|
||||
maskCanvas.id = "maskCanvas";
|
||||
backupCanvas.id = "backupCanvas";
|
||||
|
||||
this.setlayout(imgCanvas, maskCanvas);
|
||||
|
||||
// prepare content
|
||||
this.imgCanvas = imgCanvas;
|
||||
this.maskCanvas = maskCanvas;
|
||||
this.backupCanvas = backupCanvas;
|
||||
this.maskCtx = maskCanvas.getContext('2d');
|
||||
this.backupCtx = backupCanvas.getContext('2d');
|
||||
|
||||
this.setEventHandler(maskCanvas);
|
||||
|
||||
this.is_layout_created = true;
|
||||
|
||||
// replacement of onClose hook since close is not real close
|
||||
const self = this;
|
||||
const observer = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutation) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||
ComfyApp.onClipspaceEditorClosed();
|
||||
}
|
||||
|
||||
self.last_display_style = self.element.style.display;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
const config = { attributes: true };
|
||||
observer.observe(this.element, config);
|
||||
}
|
||||
|
||||
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
this.saveButton.innerText = "Save to node";
|
||||
}
|
||||
else {
|
||||
this.saveButton.innerText = "Save";
|
||||
}
|
||||
this.saveButton.disabled = false;
|
||||
|
||||
this.element.style.display = "block";
|
||||
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||
}
|
||||
|
||||
isOpened() {
|
||||
return this.element.style.display == "block";
|
||||
}
|
||||
|
||||
setImages(imgCanvas, backupCanvas) {
|
||||
const imgCtx = imgCanvas.getContext('2d');
|
||||
const backupCtx = backupCanvas.getContext('2d');
|
||||
const maskCtx = this.maskCtx;
|
||||
const maskCanvas = this.maskCanvas;
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||
|
||||
// image load
|
||||
const orig_image = new Image();
|
||||
window.addEventListener("resize", () => {
|
||||
// repositioning
|
||||
imgCanvas.width = window.innerWidth - 250;
|
||||
imgCanvas.height = window.innerHeight - 200;
|
||||
|
||||
// redraw image
|
||||
let drawWidth = orig_image.width;
|
||||
let drawHeight = orig_image.height;
|
||||
if (orig_image.width > imgCanvas.width) {
|
||||
drawWidth = imgCanvas.width;
|
||||
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
||||
}
|
||||
|
||||
if (drawHeight > imgCanvas.height) {
|
||||
drawHeight = imgCanvas.height;
|
||||
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
||||
}
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||
|
||||
// update mask
|
||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
maskCanvas.width = drawWidth;
|
||||
maskCanvas.height = drawHeight;
|
||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
});
|
||||
|
||||
const filepath = ComfyApp.clipspace.images;
|
||||
|
||||
const touched_image = new Image();
|
||||
|
||||
touched_image.onload = function() {
|
||||
backupCanvas.width = touched_image.width;
|
||||
backupCanvas.height = touched_image.height;
|
||||
|
||||
prepareRGB(touched_image, backupCanvas, backupCtx);
|
||||
};
|
||||
|
||||
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||
alpha_url.searchParams.delete('channel');
|
||||
alpha_url.searchParams.set('channel', 'a');
|
||||
touched_image.src = alpha_url;
|
||||
|
||||
// original image load
|
||||
orig_image.onload = function() {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
};
|
||||
|
||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||
rgb_url.searchParams.delete('channel');
|
||||
rgb_url.searchParams.set('channel', 'rgb');
|
||||
orig_image.src = rgb_url;
|
||||
this.image = orig_image;
|
||||
}
|
||||
|
||||
setEventHandler(maskCanvas) {
|
||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||
event.preventDefault();
|
||||
});
|
||||
|
||||
const self = this;
|
||||
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||
}
|
||||
|
||||
brush_size = 10;
|
||||
drawing_mode = false;
|
||||
lastx = -1;
|
||||
lasty = -1;
|
||||
lasttime = 0;
|
||||
|
||||
static handleKeyDown(event) {
|
||||
const self = MaskEditorDialog.instance;
|
||||
if (event.key === ']') {
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
} else if (event.key === '[') {
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
} else if(event.key === 'Enter') {
|
||||
self.save();
|
||||
}
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
static handlePointerUp(event) {
|
||||
event.preventDefault();
|
||||
MaskEditorDialog.instance.drawing_mode = false;
|
||||
}
|
||||
|
||||
updateBrushPreview(self) {
|
||||
const brush = self.brush;
|
||||
|
||||
var centerX = self.cursorX;
|
||||
var centerY = self.cursorY;
|
||||
|
||||
brush.style.width = self.brush_size * 2 + "px";
|
||||
brush.style.height = self.brush_size * 2 + "px";
|
||||
brush.style.left = (centerX - self.brush_size) + "px";
|
||||
brush.style.top = (centerY - self.brush_size) + "px";
|
||||
}
|
||||
|
||||
handleWheelEvent(self, event) {
|
||||
if(event.deltaY < 0)
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
else
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
}
|
||||
|
||||
draw_move(self, event) {
|
||||
event.preventDefault();
|
||||
|
||||
this.cursorX = event.pageX;
|
||||
this.cursorY = event.pageY;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
|
||||
if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) {
|
||||
var diff = performance.now() - self.lasttime;
|
||||
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
|
||||
var x = event.offsetX;
|
||||
var y = event.offsetY
|
||||
|
||||
if(event.offsetX == null) {
|
||||
x = event.targetTouches[0].clientX - maskRect.left;
|
||||
}
|
||||
|
||||
if(event.offsetY == null) {
|
||||
y = event.targetTouches[0].clientY - maskRect.top;
|
||||
}
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
// The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents.
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !this.drawing_mode)
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){
|
||||
brush_size *= this.last_pressure;
|
||||
}
|
||||
else {
|
||||
brush_size = this.brush_size;
|
||||
}
|
||||
|
||||
if(diff > 20 && !drawing_mode) // cannot tracking drawing_mode for touch event
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
else
|
||||
requestAnimationFrame(() => {
|
||||
self.maskCtx.beginPath();
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
|
||||
var dx = x - self.lastx;
|
||||
var dy = y - self.lasty;
|
||||
|
||||
var distance = Math.sqrt(dx * dx + dy * dy);
|
||||
var directionX = dx / distance;
|
||||
var directionY = dy / distance;
|
||||
|
||||
for (var i = 0; i < distance; i+=5) {
|
||||
var px = self.lastx + (directionX * i);
|
||||
var py = self.lasty + (directionY * i);
|
||||
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
}
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
});
|
||||
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
handlePointerDown(self, event) {
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
this.last_pressure = event.pressure;
|
||||
}
|
||||
|
||||
if ([0, 2, 5].includes(event.button)) {
|
||||
self.drawing_mode = true;
|
||||
|
||||
event.preventDefault();
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
|
||||
self.maskCtx.beginPath();
|
||||
if (event.button == 0) {
|
||||
self.maskCtx.fillStyle = "rgb(0,0,0)";
|
||||
self.maskCtx.globalCompositeOperation = "source-over";
|
||||
} else {
|
||||
self.maskCtx.globalCompositeOperation = "destination-out";
|
||||
}
|
||||
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
|
||||
self.maskCtx.fill();
|
||||
self.lastx = x;
|
||||
self.lasty = y;
|
||||
self.lasttime = performance.now();
|
||||
}
|
||||
}
|
||||
|
||||
async save() {
|
||||
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
backupCtx.drawImage(this.maskCanvas,
|
||||
0, 0, this.maskCanvas.width, this.maskCanvas.height,
|
||||
0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// paste mask data into alpha channel
|
||||
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
|
||||
const formData = new FormData();
|
||||
const filename = "clipspace-mask-" + performance.now() + ".png";
|
||||
|
||||
const item =
|
||||
{
|
||||
"filename": filename,
|
||||
"subfolder": "clipspace",
|
||||
"type": "input",
|
||||
};
|
||||
|
||||
if(ComfyApp.clipspace.images)
|
||||
ComfyApp.clipspace.images[0] = item;
|
||||
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||
|
||||
if(index >= 0)
|
||||
ComfyApp.clipspace.widgets[index].value = item;
|
||||
}
|
||||
|
||||
const dataURL = this.backupCanvas.toDataURL();
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
const original_blob = loadedImageToBlob(this.image);
|
||||
|
||||
formData.append('image', blob, filename);
|
||||
formData.append('original_image', original_blob);
|
||||
formData.append('type', "input");
|
||||
formData.append('subfolder', "clipspace");
|
||||
|
||||
this.saveButton.innerText = "Saving...";
|
||||
this.saveButton.disabled = true;
|
||||
await uploadMask(item, formData);
|
||||
ComfyApp.onClipspaceEditorSave();
|
||||
this.close();
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.MaskEditor",
|
||||
init(app) {
|
||||
ComfyApp.open_maskeditor =
|
||||
function () {
|
||||
const dlg = MaskEditorDialog.getInstance();
|
||||
if(!dlg.isOpened()) {
|
||||
dlg.show();
|
||||
}
|
||||
};
|
||||
|
||||
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
||||
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
|
||||
}
|
||||
});
|
||||
@ -300,7 +300,7 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
if (widget.type === "number") {
|
||||
if (widget.type === "number" || widget.type === "combo") {
|
||||
addValueControlWidget(this, widget, "fixed");
|
||||
}
|
||||
|
||||
|
||||
@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//when clicked on top of a node
|
||||
//and it is not interactive
|
||||
if (node && this.allow_interaction && !skip_action && !this.read_only) {
|
||||
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
|
||||
if (!this.live_mode && !node.flags.pinned) {
|
||||
this.bringToFront(node);
|
||||
} //if it wasn't selected?
|
||||
|
||||
//not dragging mouse to connect two slots
|
||||
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||
//Search for corner for resize
|
||||
if ( !skip_action &&
|
||||
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
||||
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
|
||||
//double clicking
|
||||
if (is_double_click && this.selected_nodes[node.id]) {
|
||||
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
|
||||
//double click node
|
||||
if (node.onDblClick) {
|
||||
node.onDblClick( e, pos, this );
|
||||
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
if (this.dragging_rectangle)
|
||||
{
|
||||
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
||||
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
this.ds.offset[1] += delta[1] / this.ds.scale;
|
||||
this.dirty_canvas = true;
|
||||
this.dirty_bgcanvas = true;
|
||||
} else if (this.allow_interaction && !this.read_only) {
|
||||
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
|
||||
if (this.connecting_node) {
|
||||
this.dirty_canvas = true;
|
||||
}
|
||||
|
||||
//get node over
|
||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||
|
||||
//remove mouseover flag
|
||||
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
||||
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
||||
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
if (show_text) {
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
|
||||
}
|
||||
break;
|
||||
case "toggle":
|
||||
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
if (show_text) {
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.fillText(
|
||||
w.name + " " + Number(w.value).toFixed(3),
|
||||
w.label || w.name + " " + Number(w.value).toFixed(3),
|
||||
widget_width * 0.5,
|
||||
y + H * 0.7
|
||||
);
|
||||
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
ctx.fill();
|
||||
}
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
if (w.type == "number") {
|
||||
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
|
||||
//ctx.stroke();
|
||||
ctx.fillStyle = secondary_text_color;
|
||||
if (w.name != null) {
|
||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
||||
const label = w.label || w.name;
|
||||
if (label != null) {
|
||||
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||
}
|
||||
ctx.fillStyle = text_color;
|
||||
ctx.textAlign = "right";
|
||||
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
event,
|
||||
active_widget
|
||||
) {
|
||||
if (!node.widgets || !node.widgets.length) {
|
||||
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
canvas.graph.add(group);
|
||||
};
|
||||
|
||||
/**
|
||||
* Determines the furthest nodes in each direction
|
||||
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.getBoundaryNodes = function(nodes) {
|
||||
let top = null;
|
||||
let right = null;
|
||||
let bottom = null;
|
||||
let left = null;
|
||||
for (const nID in nodes) {
|
||||
const node = nodes[nID];
|
||||
const [x, y] = node.pos;
|
||||
const [width, height] = node.size;
|
||||
|
||||
if (top === null || y < top.pos[1]) {
|
||||
top = node;
|
||||
}
|
||||
if (right === null || x + width > right.pos[0] + right.size[0]) {
|
||||
right = node;
|
||||
}
|
||||
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
|
||||
bottom = node;
|
||||
}
|
||||
if (left === null || x < left.pos[0]) {
|
||||
left = node;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"top": top,
|
||||
"right": right,
|
||||
"bottom": bottom,
|
||||
"left": left
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Determines the furthest nodes in each direction for the currently selected nodes
|
||||
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||
*/
|
||||
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
|
||||
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {LGraphNode[]} nodes a list of nodes
|
||||
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
|
||||
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
|
||||
*/
|
||||
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
|
||||
if (!nodes) {
|
||||
return;
|
||||
}
|
||||
|
||||
const canvas = LGraphCanvas.active_canvas;
|
||||
let boundaryNodes = []
|
||||
if (align_to === undefined) {
|
||||
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
|
||||
} else {
|
||||
boundaryNodes = {
|
||||
"top": align_to,
|
||||
"right": align_to,
|
||||
"bottom": align_to,
|
||||
"left": align_to
|
||||
}
|
||||
}
|
||||
|
||||
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
|
||||
switch (direction) {
|
||||
case "right":
|
||||
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
|
||||
break;
|
||||
case "left":
|
||||
node.pos[0] = boundaryNodes["left"].pos[0];
|
||||
break;
|
||||
case "top":
|
||||
node.pos[1] = boundaryNodes["top"].pos[1];
|
||||
break;
|
||||
case "bottom":
|
||||
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
canvas.dirty_canvas = true;
|
||||
canvas.dirty_bgcanvas = true;
|
||||
};
|
||||
|
||||
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
|
||||
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||
event: event,
|
||||
callback: inner_clicked,
|
||||
parentMenu: prev_menu,
|
||||
});
|
||||
|
||||
function inner_clicked(value) {
|
||||
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
|
||||
}
|
||||
}
|
||||
|
||||
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
||||
|
||||
var canvas = LGraphCanvas.active_canvas;
|
||||
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
||||
}*/
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onGroupAlign,
|
||||
})
|
||||
}
|
||||
|
||||
if (this._graph_stack && this._graph_stack.length > 0) {
|
||||
options.push(null, {
|
||||
content: "Close subgraph",
|
||||
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
callback: LGraphCanvas.onMenuNodeToSubgraph
|
||||
});
|
||||
|
||||
if (Object.keys(this.selected_nodes).length > 1) {
|
||||
options.push({
|
||||
content: "Align Selected To",
|
||||
has_submenu: true,
|
||||
callback: LGraphCanvas.onNodeAlign,
|
||||
})
|
||||
}
|
||||
|
||||
options.push(null, {
|
||||
content: "Remove",
|
||||
disabled: !(node.removable !== false && !node.block_delete ),
|
||||
|
||||
@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
|
||||
|
||||
if (res.status !== 200) {
|
||||
throw {
|
||||
response: await res.text(),
|
||||
response: await res.json(),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
|
||||
import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
import { getPngMetadata, importA1111 } from "./pnginfo.js";
|
||||
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
|
||||
/**
|
||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||
@ -25,6 +25,9 @@ export class ComfyApp {
|
||||
* @type {serialized node object}
|
||||
*/
|
||||
static clipspace = null;
|
||||
static clipspace_invalidate_handler = null;
|
||||
static open_maskeditor = null;
|
||||
static clipspace_return_node = null;
|
||||
|
||||
constructor() {
|
||||
this.ui = new ComfyUI(this);
|
||||
@ -48,6 +51,114 @@ export class ComfyApp {
|
||||
this.shiftDown = false;
|
||||
}
|
||||
|
||||
static isImageNode(node) {
|
||||
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
|
||||
}
|
||||
|
||||
static onClipspaceEditorSave() {
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
|
||||
}
|
||||
}
|
||||
|
||||
static onClipspaceEditorClosed() {
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
}
|
||||
|
||||
static copyToClipspace(node) {
|
||||
var widgets = null;
|
||||
if(node.widgets) {
|
||||
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
var imgs = undefined;
|
||||
var orig_imgs = undefined;
|
||||
if(node.imgs != undefined) {
|
||||
imgs = [];
|
||||
orig_imgs = [];
|
||||
|
||||
for (let i = 0; i < node.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = node.imgs[i].src;
|
||||
orig_imgs[i] = imgs[i];
|
||||
}
|
||||
}
|
||||
|
||||
var selectedIndex = 0;
|
||||
if(node.imageIndex) {
|
||||
selectedIndex = node.imageIndex;
|
||||
}
|
||||
|
||||
ComfyApp.clipspace = {
|
||||
'widgets': widgets,
|
||||
'imgs': imgs,
|
||||
'original_imgs': orig_imgs,
|
||||
'images': node.images,
|
||||
'selectedIndex': selectedIndex,
|
||||
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
||||
};
|
||||
|
||||
ComfyApp.clipspace_return_node = null;
|
||||
|
||||
if(ComfyApp.clipspace_invalidate_handler) {
|
||||
ComfyApp.clipspace_invalidate_handler();
|
||||
}
|
||||
}
|
||||
|
||||
static pasteFromClipspace(node) {
|
||||
if(ComfyApp.clipspace) {
|
||||
// image paste
|
||||
if(ComfyApp.clipspace.imgs && node.imgs) {
|
||||
if(node.images && ComfyApp.clipspace.images) {
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
||||
}
|
||||
else
|
||||
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace.imgs) {
|
||||
// deep-copy to cut link with clipspace
|
||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||
const img = new Image();
|
||||
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||
node.imgs = [img];
|
||||
node.imageIndex = 0;
|
||||
}
|
||||
else {
|
||||
const imgs = [];
|
||||
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
||||
imgs[i] = new Image();
|
||||
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
||||
node.imgs = imgs;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(node.widgets) {
|
||||
if(ComfyApp.clipspace.images) {
|
||||
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index >= 0) {
|
||||
node.widgets[index].value = clip_image;
|
||||
}
|
||||
}
|
||||
if(ComfyApp.clipspace.widgets) {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop && prop.type != 'button') {
|
||||
prop.value = value;
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke an extension callback
|
||||
* @param {keyof ComfyExtension} method The extension callback to execute
|
||||
@ -137,81 +248,30 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
options.push(
|
||||
{
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => {
|
||||
var widgets = null;
|
||||
if(this.widgets) {
|
||||
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
let img = new Image();
|
||||
var imgs = undefined;
|
||||
if(this.imgs != undefined) {
|
||||
img.src = this.imgs[0].src;
|
||||
imgs = [img];
|
||||
}
|
||||
// prevent conflict of clipspace content
|
||||
if(!ComfyApp.clipspace_return_node) {
|
||||
options.push({
|
||||
content: "Copy (Clipspace)",
|
||||
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||
});
|
||||
|
||||
ComfyApp.clipspace = {
|
||||
'widgets': widgets,
|
||||
'imgs': imgs,
|
||||
'original_imgs': imgs,
|
||||
'images': this.images
|
||||
};
|
||||
}
|
||||
});
|
||||
if(ComfyApp.clipspace != null) {
|
||||
options.push({
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||
});
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace != null) {
|
||||
options.push(
|
||||
{
|
||||
content: "Paste (Clipspace)",
|
||||
callback: () => {
|
||||
if(ComfyApp.clipspace != null) {
|
||||
if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop) {
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// image paste
|
||||
if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
|
||||
var filename = "";
|
||||
if(this.images && ComfyApp.clipspace.images) {
|
||||
this.images = ComfyApp.clipspace.images;
|
||||
}
|
||||
|
||||
if(ComfyApp.clipspace.images != undefined) {
|
||||
const clip_image = ComfyApp.clipspace.images[0];
|
||||
if(clip_image.subfolder != '')
|
||||
filename = `${clip_image.subfolder}/`;
|
||||
filename += `${clip_image.filename} [${clip_image.type}]`;
|
||||
}
|
||||
else if(ComfyApp.clipspace.widgets != undefined) {
|
||||
const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index_in_clip >= 0) {
|
||||
filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
|
||||
}
|
||||
}
|
||||
|
||||
const index = this.widgets.findIndex(obj => obj.name === 'image');
|
||||
if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
|
||||
this.imgs = ComfyApp.clipspace.imgs;
|
||||
|
||||
this.widgets[index].value = filename;
|
||||
if(this.widgets_values != undefined) {
|
||||
this.widgets_values[index] = filename;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.trigger('changed');
|
||||
if(ComfyApp.isImageNode(this)) {
|
||||
options.push({
|
||||
content: "Open in MaskEditor",
|
||||
callback: (obj) => {
|
||||
ComfyApp.copyToClipspace(this);
|
||||
ComfyApp.clipspace_return_node = this;
|
||||
ComfyApp.open_maskeditor();
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -263,6 +323,34 @@ export class ComfyApp {
|
||||
*/
|
||||
#addDrawBackgroundHandler(node) {
|
||||
const app = this;
|
||||
|
||||
function getImageTop(node) {
|
||||
let shiftY;
|
||||
if (node.imageOffset != null) {
|
||||
shiftY = node.imageOffset;
|
||||
} else {
|
||||
if (node.widgets?.length) {
|
||||
const w = node.widgets[node.widgets.length - 1];
|
||||
shiftY = w.last_y;
|
||||
if (w.computeSize) {
|
||||
shiftY += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
} else {
|
||||
shiftY = node.computeSize()[1];
|
||||
}
|
||||
}
|
||||
return shiftY;
|
||||
}
|
||||
|
||||
node.prototype.setSizeForImage = function () {
|
||||
const minHeight = getImageTop(this) + 220;
|
||||
if (this.size[1] < minHeight) {
|
||||
this.setSize([this.size[0], minHeight]);
|
||||
}
|
||||
};
|
||||
|
||||
node.prototype.onDrawBackground = function (ctx) {
|
||||
if (!this.flags.collapsed) {
|
||||
const output = app.nodeOutputs[this.id + ""];
|
||||
@ -283,9 +371,7 @@ export class ComfyApp {
|
||||
).then((imgs) => {
|
||||
if (this.images === output.images) {
|
||||
this.imgs = imgs.filter(Boolean);
|
||||
if (this.size[1] < 100) {
|
||||
this.size[1] = 250;
|
||||
}
|
||||
this.setSizeForImage?.();
|
||||
app.graph.setDirtyCanvas(true);
|
||||
}
|
||||
});
|
||||
@ -310,12 +396,7 @@ export class ComfyApp {
|
||||
this.imageIndex = imageIndex = 0;
|
||||
}
|
||||
|
||||
let shiftY;
|
||||
if (this.imageOffset != null) {
|
||||
shiftY = this.imageOffset;
|
||||
} else {
|
||||
shiftY = this.computeSize()[1];
|
||||
}
|
||||
const shiftY = getImageTop(this);
|
||||
|
||||
let dw = this.size[0];
|
||||
let dh = this.size[1];
|
||||
@ -703,7 +784,7 @@ export class ComfyApp {
|
||||
ctx.globalAlpha = 0.8;
|
||||
ctx.beginPath();
|
||||
if (shape == LiteGraph.BOX_SHAPE)
|
||||
ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
|
||||
ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
|
||||
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed))
|
||||
ctx.roundRect(
|
||||
-6,
|
||||
@ -715,12 +796,11 @@ export class ComfyApp {
|
||||
else if (shape == LiteGraph.CARD_SHAPE)
|
||||
ctx.roundRect(
|
||||
-6,
|
||||
-6 + LiteGraph.NODE_TITLE_HEIGHT,
|
||||
-6 - LiteGraph.NODE_TITLE_HEIGHT,
|
||||
12 + size[0] + 1,
|
||||
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
|
||||
this.round_radius * 2,
|
||||
2
|
||||
);
|
||||
[this.round_radius * 2, this.round_radius * 2, 2, 2]
|
||||
);
|
||||
else if (shape == LiteGraph.CIRCLE_SHAPE)
|
||||
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
|
||||
ctx.strokeStyle = color;
|
||||
@ -822,7 +902,9 @@ export class ComfyApp {
|
||||
await this.#loadExtensions();
|
||||
|
||||
// Create and mount the LiteGraph in the DOM
|
||||
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
|
||||
const mainCanvas = document.createElement("canvas")
|
||||
mainCanvas.style.touchAction = "none"
|
||||
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
|
||||
canvasEl.tabIndex = "1";
|
||||
document.body.prepend(canvasEl);
|
||||
|
||||
@ -941,7 +1023,8 @@ export class ComfyApp {
|
||||
for (const o in nodeData["output"]) {
|
||||
const output = nodeData["output"][o];
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
this.addOutput(outputName, output);
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
@ -1186,7 +1269,7 @@ export class ComfyApp {
|
||||
try {
|
||||
await api.queuePrompt(number, p);
|
||||
} catch (error) {
|
||||
this.ui.dialog.show(error.response || error.toString());
|
||||
this.ui.dialog.show(error.response.error || error.toString());
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1232,6 +1315,11 @@ export class ComfyApp {
|
||||
this.loadGraphData(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1262,12 +1350,12 @@ export class ComfyApp {
|
||||
|
||||
for(const widgetNum in node.widgets) {
|
||||
const widget = node.widgets[widgetNum]
|
||||
|
||||
if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
|
||||
widget.options.values = def["input"]["required"][widget.name][0];
|
||||
|
||||
if(!widget.options.values.includes(widget.value)) {
|
||||
if(widget.name != 'image' && !widget.options.values.includes(widget.value)) {
|
||||
widget.value = widget.options.values[0];
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,6 +47,22 @@ export function getPngMetadata(file) {
|
||||
});
|
||||
}
|
||||
|
||||
export function getLatentMetadata(file) {
|
||||
return new Promise((r) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
const safetensorsData = new Uint8Array(event.target.result);
|
||||
const dataView = new DataView(safetensorsData.buffer);
|
||||
let header_size = dataView.getUint32(0, true);
|
||||
let offset = 8;
|
||||
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||
r(header.__metadata__);
|
||||
};
|
||||
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
export async function importA1111(graph, parameters) {
|
||||
const p = parameters.lastIndexOf("\nSteps:");
|
||||
if (p > -1) {
|
||||
|
||||
@ -465,7 +465,7 @@ export class ComfyUI {
|
||||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png",
|
||||
accept: ".json,image/png,.latent",
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
@ -581,6 +581,7 @@ export class ComfyUI {
|
||||
}),
|
||||
$el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }),
|
||||
$el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }),
|
||||
$el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }),
|
||||
$el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => {
|
||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||
app.clean();
|
||||
|
||||
@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
|
||||
var v = valueControl.value;
|
||||
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||
let current_length = targetWidget.options.values.length;
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
switch (v) {
|
||||
case "increment":
|
||||
current_index += 1;
|
||||
break;
|
||||
case "decrement":
|
||||
current_index -= 1;
|
||||
break;
|
||||
case "randomize":
|
||||
current_index = Math.floor(Math.random() * current_length);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
current_index = Math.max(0, current_index);
|
||||
current_index = Math.min(current_length - 1, current_index);
|
||||
if (current_index >= 0) {
|
||||
let value = targetWidget.options.values[current_index];
|
||||
targetWidget.value = value;
|
||||
targetWidget.callback(value);
|
||||
}
|
||||
} else { //number
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
max = Math.min(1125899906842624, max);
|
||||
min = Math.max(-1125899906842624, min);
|
||||
let range = (max - min) / (targetWidget.options.step / 10);
|
||||
|
||||
//adjust values based on valueControl Behaviour
|
||||
switch (v) {
|
||||
case "fixed":
|
||||
break;
|
||||
case "increment":
|
||||
targetWidget.value += targetWidget.options.step / 10;
|
||||
break;
|
||||
case "decrement":
|
||||
targetWidget.value -= targetWidget.options.step / 10;
|
||||
break;
|
||||
case "randomize":
|
||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
}
|
||||
return valueControl;
|
||||
};
|
||||
@ -130,17 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
|
||||
computeSize(node.size);
|
||||
}
|
||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||
const t = ctx.getTransform();
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin / window.devicePixelRatio, (margin + y) / window.devicePixelRatio);
|
||||
|
||||
Object.assign(this.inputEl.style, {
|
||||
left: `${(t.a * margin + t.e) / window.devicePixelRatio}px`,
|
||||
top: `${(t.d * (y + widgetHeight - margin - 3) + t.f) / window.devicePixelRatio}px`,
|
||||
width: `${(widgetWidth - margin * 2 - 3) * t.a / window.devicePixelRatio}px`,
|
||||
background: (!node.color)?'':node.color,
|
||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d / window.devicePixelRatio}px`,
|
||||
transformOrigin: "0 0",
|
||||
transform: transform,
|
||||
left: "0px",
|
||||
top: "0px",
|
||||
width: `${(widgetWidth - (margin * 2)) / window.devicePixelRatio}px`,
|
||||
height: `${(this.parent.inputHeight - (margin * 2)) / window.devicePixelRatio}px`,
|
||||
position: "absolute",
|
||||
zIndex: 1,
|
||||
fontSize: `${t.d * 10.0 / window.devicePixelRatio}px`,
|
||||
background: (!node.color)?'':node.color,
|
||||
color: (!node.color)?'':'white',
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
});
|
||||
this.inputEl.hidden = !visible;
|
||||
},
|
||||
@ -260,22 +292,51 @@ export const ComfyWidgets = {
|
||||
let uploadWidget;
|
||||
|
||||
function showImage(name) {
|
||||
// Position the image somewhere sensible
|
||||
if (!node.imageOffset) {
|
||||
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
|
||||
}
|
||||
|
||||
const img = new Image();
|
||||
img.onload = () => {
|
||||
node.imgs = [img];
|
||||
app.graph.setDirtyCanvas(true);
|
||||
};
|
||||
img.src = `/view?filename=${name}&type=input`;
|
||||
if ((node.size[1] - node.imageOffset) < 100) {
|
||||
node.size[1] = 250 + node.imageOffset;
|
||||
let folder_separator = name.lastIndexOf("/");
|
||||
let subfolder = "";
|
||||
if (folder_separator > -1) {
|
||||
subfolder = name.substring(0, folder_separator);
|
||||
name = name.substring(folder_separator + 1);
|
||||
}
|
||||
img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`;
|
||||
node.setSizeForImage?.();
|
||||
}
|
||||
|
||||
var default_value = imageWidget.value;
|
||||
Object.defineProperty(imageWidget, "value", {
|
||||
set : function(value) {
|
||||
this._real_value = value;
|
||||
},
|
||||
|
||||
get : function() {
|
||||
let value = "";
|
||||
if (this._real_value) {
|
||||
value = this._real_value;
|
||||
} else {
|
||||
return default_value;
|
||||
}
|
||||
|
||||
if (value.filename) {
|
||||
let real_value = value;
|
||||
value = "";
|
||||
if (real_value.subfolder) {
|
||||
value = real_value.subfolder + "/";
|
||||
}
|
||||
|
||||
value += real_value.filename;
|
||||
|
||||
if(real_value.type && real_value.type !== "input")
|
||||
value += ` [${real_value.type}]`;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
});
|
||||
|
||||
// Add our own callback to the combo widget to render an image when it changes
|
||||
const cb = node.callback;
|
||||
imageWidget.callback = function () {
|
||||
|
||||
@ -39,6 +39,8 @@ body {
|
||||
padding: 2px;
|
||||
resize: none;
|
||||
border: none;
|
||||
box-sizing: border-box;
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.comfy-modal {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user