mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 06:40:16 +08:00
Merge remote-tracking branch 'upstream/master' into addBatchIndex
This commit is contained in:
commit
abc3d0baf2
@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||||
|
"""DPM-Solver++(2M) SDE."""
|
||||||
|
|
||||||
|
if solver_type not in {'heun', 'midpoint'}:
|
||||||
|
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
|
||||||
|
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
old_denoised = None
|
||||||
|
h_last = None
|
||||||
|
h = None
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
# Denoising step
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
# DPM-Solver++(2M) SDE
|
||||||
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = s - t
|
||||||
|
eta_h = eta * h
|
||||||
|
|
||||||
|
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
|
||||||
|
|
||||||
|
if old_denoised is not None:
|
||||||
|
r = h_last / h
|
||||||
|
if solver_type == 'heun':
|
||||||
|
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
|
||||||
|
elif solver_type == 'midpoint':
|
||||||
|
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
|
||||||
|
|
||||||
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||||
|
|
||||||
|
old_denoised = denoised
|
||||||
|
h_last = h
|
||||||
|
return x
|
||||||
|
|||||||
@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
return x+h
|
return x+h
|
||||||
|
|
||||||
|
def slice_attention(q, k, v):
|
||||||
|
r1 = torch.zeros_like(k, device=q.device)
|
||||||
|
scale = (int(q.shape[-1])**(-0.5))
|
||||||
|
|
||||||
|
mem_free_total = model_management.get_free_memory(q.device)
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||||
|
|
||||||
|
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||||
|
del s2
|
||||||
|
break
|
||||||
|
except model_management.OOM_EXCEPTION as e:
|
||||||
|
steps *= 2
|
||||||
|
if steps > 128:
|
||||||
|
raise e
|
||||||
|
print("out of memory error, increasing steps and trying again", steps)
|
||||||
|
|
||||||
|
return r1
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
|
|||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
b,c,h,w = q.shape
|
b,c,h,w = q.shape
|
||||||
scale = (int(c)**(-0.5))
|
|
||||||
|
|
||||||
q = q.reshape(b,c,h*w)
|
q = q.reshape(b,c,h*w)
|
||||||
q = q.permute(0,2,1) # b,hw,c
|
q = q.permute(0,2,1) # b,hw,c
|
||||||
k = k.reshape(b,c,h*w) # b,c,hw
|
k = k.reshape(b,c,h*w) # b,c,hw
|
||||||
v = v.reshape(b,c,h*w)
|
v = v.reshape(b,c,h*w)
|
||||||
|
|
||||||
r1 = torch.zeros_like(k, device=q.device)
|
r1 = slice_attention(q, k, v)
|
||||||
|
|
||||||
mem_free_total = model_management.get_free_memory(q.device)
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
|
||||||
mem_required = tensor_size * modifier
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
|
||||||
|
|
||||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
|
||||||
del s2
|
|
||||||
break
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
|
||||||
steps *= 2
|
|
||||||
if steps > 128:
|
|
||||||
raise e
|
|
||||||
print("out of memory error, increasing steps and trying again", steps)
|
|
||||||
|
|
||||||
h_ = r1.reshape(b,c,h,w)
|
h_ = r1.reshape(b,c,h,w)
|
||||||
del r1
|
del r1
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
return x+h_
|
return x+h_
|
||||||
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
B, C, H, W = q.shape
|
B, C, H, W = q.shape
|
||||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
|
|
||||||
|
|
||||||
q, k, v = map(
|
q, k, v = map(
|
||||||
lambda t: t.unsqueeze(3)
|
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||||
.reshape(B, t.shape[1], 1, C)
|
|
||||||
.permute(0, 2, 1, 3)
|
|
||||||
.reshape(B * 1, t.shape[1], C)
|
|
||||||
.contiguous(),
|
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
|
||||||
|
|
||||||
out = (
|
try:
|
||||||
out.unsqueeze(0)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
.reshape(B, 1, out.shape[1], C)
|
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||||
.permute(0, 2, 1, 3)
|
except model_management.OOM_EXCEPTION as e:
|
||||||
.reshape(B, out.shape[1], C)
|
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
)
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x+out
|
return x+out
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
"""
|
"""
|
||||||
B, N, _ = metric.shape
|
B, N, _ = metric.shape
|
||||||
|
|
||||||
if r <= 0:
|
if r <= 0 or w == 1 or h == 1:
|
||||||
return do_nothing, do_nothing
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||||
|
|||||||
@ -127,6 +127,32 @@ if args.cpu:
|
|||||||
|
|
||||||
print(f"Set vram state to: {vram_state.name}")
|
print(f"Set vram state to: {vram_state.name}")
|
||||||
|
|
||||||
|
def get_torch_device():
|
||||||
|
global xpu_available
|
||||||
|
global directml_enabled
|
||||||
|
if directml_enabled:
|
||||||
|
global directml_device
|
||||||
|
return directml_device
|
||||||
|
if vram_state == VRAMState.MPS:
|
||||||
|
return torch.device("mps")
|
||||||
|
if vram_state == VRAMState.CPU:
|
||||||
|
return torch.device("cpu")
|
||||||
|
else:
|
||||||
|
if xpu_available:
|
||||||
|
return torch.device("xpu")
|
||||||
|
else:
|
||||||
|
return torch.cuda.current_device()
|
||||||
|
|
||||||
|
def get_torch_device_name(device):
|
||||||
|
if hasattr(device, 'type'):
|
||||||
|
return "{}".format(device.type)
|
||||||
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("Using device:", get_torch_device_name(get_torch_device()))
|
||||||
|
except:
|
||||||
|
print("Could not pick default device.")
|
||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
current_gpu_controlnets = []
|
current_gpu_controlnets = []
|
||||||
@ -233,22 +259,6 @@ def unload_if_low_vram(model):
|
|||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_torch_device():
|
|
||||||
global xpu_available
|
|
||||||
global directml_enabled
|
|
||||||
if directml_enabled:
|
|
||||||
global directml_device
|
|
||||||
return directml_device
|
|
||||||
if vram_state == VRAMState.MPS:
|
|
||||||
return torch.device("mps")
|
|
||||||
if vram_state == VRAMState.CPU:
|
|
||||||
return torch.device("cpu")
|
|
||||||
else:
|
|
||||||
if xpu_available:
|
|
||||||
return torch.device("xpu")
|
|
||||||
else:
|
|
||||||
return torch.cuda.current_device()
|
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
|
|||||||
@ -2,17 +2,26 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import math
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def prepare_noise(latent_image, seed, skip=0):
|
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||||
"""
|
"""
|
||||||
creates random noise given a latent image and a seed.
|
creates random noise given a latent image and a seed.
|
||||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||||
"""
|
"""
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.manual_seed(seed)
|
||||||
for _ in range(skip):
|
if noise_inds is None:
|
||||||
|
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
|
|
||||||
|
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||||
|
noises = []
|
||||||
|
for i in range(unique_inds[-1]+1):
|
||||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
if i in unique_inds:
|
||||||
return noise
|
noises.append(noise)
|
||||||
|
noises = [noises[i] for i in inverse]
|
||||||
|
noises = torch.cat(noises, axis=0)
|
||||||
|
return noises
|
||||||
|
|
||||||
def prepare_mask(noise_mask, shape, device):
|
def prepare_mask(noise_mask, shape, device):
|
||||||
"""ensures noise mask is of proper dimensions"""
|
"""ensures noise mask is of proper dimensions"""
|
||||||
|
|||||||
@ -6,6 +6,10 @@ import contextlib
|
|||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
|
import math
|
||||||
|
|
||||||
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||||
|
return abs(a*b) // math.gcd(a, b)
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if c1.keys() != c2.keys():
|
if c1.keys() != c2.keys():
|
||||||
return False
|
return False
|
||||||
if 'c_crossattn' in c1:
|
if 'c_crossattn' in c1:
|
||||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
s1 = c1['c_crossattn'].shape
|
||||||
return False
|
s2 = c2['c_crossattn'].shape
|
||||||
|
if s1 != s2:
|
||||||
|
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
|
||||||
|
return False
|
||||||
|
|
||||||
|
mult_min = lcm(s1[1], s2[1])
|
||||||
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
|
return False
|
||||||
if 'c_concat' in c1:
|
if 'c_concat' in c1:
|
||||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||||
return False
|
return False
|
||||||
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
c_crossattn = []
|
c_crossattn = []
|
||||||
c_concat = []
|
c_concat = []
|
||||||
c_adm = []
|
c_adm = []
|
||||||
|
crossattn_max_len = 0
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
if 'c_crossattn' in x:
|
if 'c_crossattn' in x:
|
||||||
c_crossattn.append(x['c_crossattn'])
|
c = x['c_crossattn']
|
||||||
|
if crossattn_max_len == 0:
|
||||||
|
crossattn_max_len = c.shape[1]
|
||||||
|
else:
|
||||||
|
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
|
||||||
|
c_crossattn.append(c)
|
||||||
if 'c_concat' in x:
|
if 'c_concat' in x:
|
||||||
c_concat.append(x['c_concat'])
|
c_concat.append(x['c_concat'])
|
||||||
if 'c_adm' in x:
|
if 'c_adm' in x:
|
||||||
c_adm.append(x['c_adm'])
|
c_adm.append(x['c_adm'])
|
||||||
out = {}
|
out = {}
|
||||||
if len(c_crossattn) > 0:
|
c_crossattn_out = []
|
||||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
for c in c_crossattn:
|
||||||
|
if c.shape[1] < crossattn_max_len:
|
||||||
|
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||||
|
c_crossattn_out.append(c)
|
||||||
|
|
||||||
|
if len(c_crossattn_out) > 0:
|
||||||
|
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||||
if len(c_concat) > 0:
|
if len(c_concat) > 0:
|
||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
if len(c_adm) > 0:
|
if len(c_adm) > 0:
|
||||||
@ -471,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
|||||||
|
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -508,6 +532,8 @@ class KSampler:
|
|||||||
|
|
||||||
if self.scheduler == "karras":
|
if self.scheduler == "karras":
|
||||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
|
elif self.scheduler == "exponential":
|
||||||
|
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
||||||
elif self.scheduler == "normal":
|
elif self.scheduler == "normal":
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
elif self.scheduler == "simple":
|
elif self.scheduler == "simple":
|
||||||
|
|||||||
17
comfy/sd.py
17
comfy/sd.py
@ -581,12 +581,9 @@ class VAE:
|
|||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def resize_image_to(tensor, target_latent_tensor, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
|
|
||||||
target_batch_size = target_latent_tensor.shape[0]
|
|
||||||
|
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
print(current_batch_size, target_batch_size)
|
#print(current_batch_size, target_batch_size)
|
||||||
if current_batch_size == 1:
|
if current_batch_size == 1:
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
@ -623,7 +620,9 @@ class ControlNet:
|
|||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
if self.control_model.dtype == torch.float16:
|
if self.control_model.dtype == torch.float16:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
@ -794,10 +793,14 @@ class T2IAdapter:
|
|||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
|
self.control_input = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||||
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
if self.control_input is None:
|
||||||
self.t2i_model.to(self.device)
|
self.t2i_model.to(self.device)
|
||||||
self.control_input = self.t2i_model(self.cond_hint)
|
self.control_input = self.t2i_model(self.cond_hint)
|
||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class MaskToImage:
|
|||||||
FUNCTION = "mask_to_image"
|
FUNCTION = "mask_to_image"
|
||||||
|
|
||||||
def mask_to_image(self, mask):
|
def mask_to_image(self, mask):
|
||||||
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
|
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
class ImageToMask:
|
class ImageToMask:
|
||||||
|
|||||||
@ -59,6 +59,12 @@ class Blend:
|
|||||||
def g(self, x):
|
def g(self, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
|
def gaussian_kernel(kernel_size: int, sigma: float):
|
||||||
|
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
||||||
|
d = torch.sqrt(x * x + y * y)
|
||||||
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||||
|
return g / g.sum()
|
||||||
|
|
||||||
class Blur:
|
class Blur:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
@ -88,12 +94,6 @@ class Blur:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
def gaussian_kernel(self, kernel_size: int, sigma: float):
|
|
||||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
|
|
||||||
d = torch.sqrt(x * x + y * y)
|
|
||||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
|
||||||
return g / g.sum()
|
|
||||||
|
|
||||||
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
|
||||||
if blur_radius == 0:
|
if blur_radius == 0:
|
||||||
return (image,)
|
return (image,)
|
||||||
@ -101,10 +101,11 @@ class Blur:
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = blur_radius * 2 + 1
|
kernel_size = blur_radius * 2 + 1
|
||||||
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
|
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||||
|
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
|
||||||
blurred = blurred.permute(0, 2, 3, 1)
|
blurred = blurred.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
return (blurred,)
|
return (blurred,)
|
||||||
@ -167,9 +168,15 @@ class Sharpen:
|
|||||||
"max": 31,
|
"max": 31,
|
||||||
"step": 1
|
"step": 1
|
||||||
}),
|
}),
|
||||||
"alpha": ("FLOAT", {
|
"sigma": ("FLOAT", {
|
||||||
"default": 1.0,
|
"default": 1.0,
|
||||||
"min": 0.1,
|
"min": 0.1,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.1
|
||||||
|
}),
|
||||||
|
"alpha": ("FLOAT", {
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.0,
|
||||||
"max": 5.0,
|
"max": 5.0,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
}),
|
}),
|
||||||
@ -181,21 +188,21 @@ class Sharpen:
|
|||||||
|
|
||||||
CATEGORY = "image/postprocessing"
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
|
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
|
||||||
if sharpen_radius == 0:
|
if sharpen_radius == 0:
|
||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = sharpen_radius * 2 + 1
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
|
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
kernel[center, center] = kernel_size**2
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
kernel *= alpha
|
|
||||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
|
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
||||||
|
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
sharpened = sharpened.permute(0, 2, 3, 1)
|
sharpened = sharpened.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
result = torch.clamp(sharpened, 0, 1)
|
result = torch.clamp(sharpened, 0, 1)
|
||||||
|
|||||||
108
comfy_extras/nodes_rebatch.py
Normal file
108
comfy_extras/nodes_rebatch.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class LatentRebatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "latents": ("LATENT",),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
OUTPUT_IS_LIST = (True, )
|
||||||
|
|
||||||
|
FUNCTION = "rebatch"
|
||||||
|
|
||||||
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_batch(latents, list_ind, offset):
|
||||||
|
'''prepare a batch out of the list of latents'''
|
||||||
|
samples = latents[list_ind]['samples']
|
||||||
|
shape = samples.shape
|
||||||
|
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||||
|
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||||
|
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||||
|
if mask.shape[0] < samples.shape[0]:
|
||||||
|
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||||
|
if 'batch_index' in latents[list_ind]:
|
||||||
|
batch_inds = latents[list_ind]['batch_index']
|
||||||
|
else:
|
||||||
|
batch_inds = [x+offset for x in range(shape[0])]
|
||||||
|
return samples, mask, batch_inds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_slices(indexable, num, batch_size):
|
||||||
|
'''divides an indexable object into num slices of length batch_size, and a remainder'''
|
||||||
|
slices = []
|
||||||
|
for i in range(num):
|
||||||
|
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||||
|
if num * batch_size < len(indexable):
|
||||||
|
return slices, indexable[num * batch_size:]
|
||||||
|
else:
|
||||||
|
return slices, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def slice_batch(batch, num, batch_size):
|
||||||
|
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||||
|
return list(zip(*result))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cat_batch(batch1, batch2):
|
||||||
|
if batch1[0] is None:
|
||||||
|
return batch2
|
||||||
|
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def rebatch(self, latents, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
current_batch = (None, None, None)
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
for i in range(len(latents)):
|
||||||
|
# fetch new entry of list
|
||||||
|
#samples, masks, indices = self.get_batch(latents, i)
|
||||||
|
next_batch = self.get_batch(latents, i, processed)
|
||||||
|
processed += len(next_batch[2])
|
||||||
|
# set to current if current is None
|
||||||
|
if current_batch[0] is None:
|
||||||
|
current_batch = next_batch
|
||||||
|
# add previous to list if dimensions do not match
|
||||||
|
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||||
|
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
current_batch = next_batch
|
||||||
|
# cat if everything checks out
|
||||||
|
else:
|
||||||
|
current_batch = self.cat_batch(current_batch, next_batch)
|
||||||
|
|
||||||
|
# add to list if dimensions gone above target batch size
|
||||||
|
if current_batch[0].shape[0] > batch_size:
|
||||||
|
num = current_batch[0].shape[0] // batch_size
|
||||||
|
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
|
||||||
|
|
||||||
|
for i in range(num):
|
||||||
|
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||||
|
|
||||||
|
current_batch = remainder
|
||||||
|
|
||||||
|
#add remainder
|
||||||
|
if current_batch[0] is not None:
|
||||||
|
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
|
||||||
|
#get rid of empty masks
|
||||||
|
for s in output_list:
|
||||||
|
if s['noise_mask'].mean() == 1.0:
|
||||||
|
del s['noise_mask']
|
||||||
|
|
||||||
|
return (output_list,)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"RebatchLatents": LatentRebatch,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"RebatchLatents": "Rebatch Latents",
|
||||||
|
}
|
||||||
@ -17,7 +17,7 @@ class UpscaleModelLoader:
|
|||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
out = model_loading.load_state_dict(sd).eval()
|
out = model_loading.load_state_dict(sd).eval()
|
||||||
return (out, )
|
return (out, )
|
||||||
|
|
||||||
|
|||||||
169
execution.py
169
execution.py
@ -6,6 +6,7 @@ import threading
|
|||||||
import heapq
|
import heapq
|
||||||
import traceback
|
import traceback
|
||||||
import gc
|
import gc
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
@ -26,21 +27,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
else:
|
else:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||||
input_data_all[x] = input_data
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
h = valid_inputs["hidden"]
|
h = valid_inputs["hidden"]
|
||||||
for x in h:
|
for x in h:
|
||||||
if h[x] == "PROMPT":
|
if h[x] == "PROMPT":
|
||||||
input_data_all[x] = prompt
|
input_data_all[x] = [prompt]
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
if "extra_pnginfo" in extra_data:
|
if "extra_pnginfo" in extra_data:
|
||||||
input_data_all[x] = extra_data['extra_pnginfo']
|
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = unique_id
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|
||||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
|
# check if node wants the lists
|
||||||
|
intput_is_list = False
|
||||||
|
if hasattr(obj, "INPUT_IS_LIST"):
|
||||||
|
intput_is_list = obj.INPUT_IS_LIST
|
||||||
|
|
||||||
|
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||||
|
|
||||||
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
|
def slice_dict(d, i):
|
||||||
|
d_new = dict()
|
||||||
|
for k,v in d.items():
|
||||||
|
d_new[k] = v[i if len(v) > i else -1]
|
||||||
|
return d_new
|
||||||
|
|
||||||
|
results = []
|
||||||
|
if intput_is_list:
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**input_data_all))
|
||||||
|
else:
|
||||||
|
for i in range(max_len_input):
|
||||||
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all):
|
||||||
|
|
||||||
|
results = []
|
||||||
|
uis = []
|
||||||
|
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||||
|
|
||||||
|
for r in return_values:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
if 'ui' in r:
|
||||||
|
uis.append(r['ui'])
|
||||||
|
if 'result' in r:
|
||||||
|
results.append(r['result'])
|
||||||
|
else:
|
||||||
|
results.append(r)
|
||||||
|
|
||||||
|
output = []
|
||||||
|
if len(results) > 0:
|
||||||
|
# check which outputs need concatenating
|
||||||
|
output_is_list = [False] * len(results[0])
|
||||||
|
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||||
|
output_is_list = obj.OUTPUT_IS_LIST
|
||||||
|
|
||||||
|
# merge node execution results
|
||||||
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
|
if is_list:
|
||||||
|
output.append([x for o in results for x in o[i]])
|
||||||
|
else:
|
||||||
|
output.append([o[i] for o in results])
|
||||||
|
|
||||||
|
ui = dict()
|
||||||
|
if len(uis) > 0:
|
||||||
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
|
return output, ui
|
||||||
|
|
||||||
|
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
@ -55,21 +117,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if input_unique_id not in outputs:
|
if input_unique_id not in outputs:
|
||||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
|
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||||
|
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = unique_id
|
server.last_node_id = unique_id
|
||||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
|
|
||||||
nodes.before_node_execution()
|
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
outputs[unique_id] = output_data
|
||||||
if "ui" in outputs[unique_id]:
|
if len(output_ui) > 0:
|
||||||
|
outputs_ui[unique_id] = output_ui
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
if "result" in outputs[unique_id]:
|
|
||||||
outputs[unique_id] = outputs[unique_id]["result"]
|
|
||||||
executed.add(unique_id)
|
executed.add(unique_id)
|
||||||
|
|
||||||
def recursive_will_execute(prompt, outputs, current_item):
|
def recursive_will_execute(prompt, outputs, current_item):
|
||||||
@ -105,7 +166,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||||
if input_data_all is not None:
|
if input_data_all is not None:
|
||||||
try:
|
try:
|
||||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
except:
|
except:
|
||||||
to_delete = True
|
to_delete = True
|
||||||
@ -144,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.outputs = {}
|
self.outputs = {}
|
||||||
|
self.outputs_ui = {}
|
||||||
self.old_prompt = {}
|
self.old_prompt = {}
|
||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
def execute(self, prompt, extra_data={}):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
|
|
||||||
if "client_id" in extra_data:
|
if "client_id" in extra_data:
|
||||||
@ -155,6 +218,10 @@ class PromptExecutor:
|
|||||||
else:
|
else:
|
||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
|
execution_start_time = time.perf_counter()
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
#delete cached outputs if nodes don't exist for them
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@ -169,32 +236,34 @@ class PromptExecutor:
|
|||||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||||
|
|
||||||
current_outputs = set(self.outputs.keys())
|
current_outputs = set(self.outputs.keys())
|
||||||
|
for x in list(self.outputs_ui.keys()):
|
||||||
|
if x not in current_outputs:
|
||||||
|
d = self.outputs_ui.pop(x)
|
||||||
|
del d
|
||||||
|
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||||
executed = set()
|
executed = set()
|
||||||
try:
|
try:
|
||||||
to_execute = []
|
to_execute = []
|
||||||
for x in prompt:
|
for x in list(execute_outputs):
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
to_execute += [(0, x)]
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
to_execute += [(0, x)]
|
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while len(to_execute) > 0:
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||||
x = to_execute.pop(0)[-1]
|
x = to_execute.pop(0)[-1]
|
||||||
|
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
|
||||||
if hasattr(class_, 'OUTPUT_NODE'):
|
|
||||||
if class_.OUTPUT_NODE == True:
|
|
||||||
valid = False
|
|
||||||
try:
|
|
||||||
m = validate_inputs(prompt, x)
|
|
||||||
valid = m[0]
|
|
||||||
except:
|
|
||||||
valid = False
|
|
||||||
if valid:
|
|
||||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
if isinstance(e, comfy.model_management.InterruptProcessingException):
|
||||||
|
print("Processing interrupted")
|
||||||
|
else:
|
||||||
|
message = str(traceback.format_exc())
|
||||||
|
print(message)
|
||||||
|
if self.server.client_id is not None:
|
||||||
|
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
|
||||||
|
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for o in self.outputs:
|
for o in self.outputs:
|
||||||
if (o not in current_outputs) and (o not in executed):
|
if (o not in current_outputs) and (o not in executed):
|
||||||
@ -210,14 +279,18 @@ class PromptExecutor:
|
|||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
if self.server.client_id is not None:
|
if self.server.client_id is not None:
|
||||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
|
||||||
|
|
||||||
|
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||||
gc.collect()
|
gc.collect()
|
||||||
comfy.model_management.soft_empty_cache()
|
comfy.model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item):
|
def validate_inputs(prompt, item, validated):
|
||||||
unique_id = item
|
unique_id = item
|
||||||
|
if unique_id in validated:
|
||||||
|
return validated[unique_id]
|
||||||
|
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
class_type = prompt[unique_id]['class_type']
|
class_type = prompt[unique_id]['class_type']
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
@ -238,8 +311,9 @@ def validate_inputs(prompt, item):
|
|||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
if r[val[1]] != type_input:
|
if r[val[1]] != type_input:
|
||||||
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
|
||||||
r = validate_inputs(prompt, o_id)
|
r = validate_inputs(prompt, o_id, validated)
|
||||||
if r[0] == False:
|
if r[0] == False:
|
||||||
|
validated[o_id] = r
|
||||||
return r
|
return r
|
||||||
else:
|
else:
|
||||||
if type_input == "INT":
|
if type_input == "INT":
|
||||||
@ -254,20 +328,25 @@ def validate_inputs(prompt, item):
|
|||||||
|
|
||||||
if len(info) > 1:
|
if len(info) > 1:
|
||||||
if "min" in info[1] and val < info[1]["min"]:
|
if "min" in info[1] and val < info[1]["min"]:
|
||||||
return (False, "Value smaller than min. {}, {}".format(class_type, x))
|
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
|
||||||
if "max" in info[1] and val > info[1]["max"]:
|
if "max" in info[1] and val > info[1]["max"]:
|
||||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))
|
||||||
|
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
if ret != True:
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||||
return (False, "{}, {}".format(class_type, ret))
|
for r in ret:
|
||||||
|
if r != True:
|
||||||
|
return (False, "{}, {}".format(class_type, r))
|
||||||
else:
|
else:
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||||
return (True, "")
|
|
||||||
|
ret = (True, "")
|
||||||
|
validated[unique_id] = ret
|
||||||
|
return ret
|
||||||
|
|
||||||
def validate_prompt(prompt):
|
def validate_prompt(prompt):
|
||||||
outputs = set()
|
outputs = set()
|
||||||
@ -281,11 +360,12 @@ def validate_prompt(prompt):
|
|||||||
|
|
||||||
good_outputs = set()
|
good_outputs = set()
|
||||||
errors = []
|
errors = []
|
||||||
|
validated = {}
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
valid = False
|
valid = False
|
||||||
reason = ""
|
reason = ""
|
||||||
try:
|
try:
|
||||||
m = validate_inputs(prompt, o)
|
m = validate_inputs(prompt, o, validated)
|
||||||
valid = m[0]
|
valid = m[0]
|
||||||
reason = m[1]
|
reason = m[1]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -294,7 +374,7 @@ def validate_prompt(prompt):
|
|||||||
reason = "Parsing error"
|
reason = "Parsing error"
|
||||||
|
|
||||||
if valid == True:
|
if valid == True:
|
||||||
good_outputs.add(x)
|
good_outputs.add(o)
|
||||||
else:
|
else:
|
||||||
print("Failed to validate prompt for output {} {}".format(o, reason))
|
print("Failed to validate prompt for output {} {}".format(o, reason))
|
||||||
print("output will be ignored")
|
print("output will be ignored")
|
||||||
@ -304,7 +384,7 @@ def validate_prompt(prompt):
|
|||||||
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
|
||||||
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
|
||||||
|
|
||||||
return (True, "")
|
return (True, "", list(good_outputs))
|
||||||
|
|
||||||
|
|
||||||
class PromptQueue:
|
class PromptQueue:
|
||||||
@ -340,8 +420,7 @@ class PromptQueue:
|
|||||||
prompt = self.currently_running.pop(item_id)
|
prompt = self.currently_running.pop(item_id)
|
||||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
if "ui" in outputs[o]:
|
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||||
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
|
|
||||||
self.server.queue_updated()
|
self.server.queue_updated()
|
||||||
|
|
||||||
def get_current_queue(self):
|
def get_current_queue(self):
|
||||||
|
|||||||
@ -147,4 +147,37 @@ def get_filename_list(folder_name):
|
|||||||
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
|
||||||
return sorted(list(output_list))
|
return sorted(list(output_list))
|
||||||
|
|
||||||
|
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||||
|
def map_filename(filename):
|
||||||
|
prefix_len = len(os.path.basename(filename_prefix))
|
||||||
|
prefix = filename[:prefix_len + 1]
|
||||||
|
try:
|
||||||
|
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||||
|
except:
|
||||||
|
digits = 0
|
||||||
|
return (digits, prefix)
|
||||||
|
|
||||||
|
def compute_vars(input, image_width, image_height):
|
||||||
|
input = input.replace("%width%", str(image_width))
|
||||||
|
input = input.replace("%height%", str(image_height))
|
||||||
|
return input
|
||||||
|
|
||||||
|
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||||
|
|
||||||
|
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
||||||
|
filename = os.path.basename(os.path.normpath(filename_prefix))
|
||||||
|
|
||||||
|
full_output_folder = os.path.join(output_dir, subfolder)
|
||||||
|
|
||||||
|
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||||
|
print("Saving image outside the output folder is not allowed.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||||
|
except ValueError:
|
||||||
|
counter = 1
|
||||||
|
except FileNotFoundError:
|
||||||
|
os.makedirs(full_output_folder, exist_ok=True)
|
||||||
|
counter = 1
|
||||||
|
return full_output_folder, filename, counter, subfolder, filename_prefix
|
||||||
|
|||||||
4
main.py
4
main.py
@ -33,8 +33,8 @@ def prompt_worker(q, server):
|
|||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
item, item_id = q.get()
|
item, item_id = q.get()
|
||||||
e.execute(item[-2], item[-1])
|
e.execute(item[2], item[1], item[3], item[4])
|
||||||
q.task_done(item_id, e.outputs)
|
q.task_done(item_id, e.outputs_ui)
|
||||||
|
|
||||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||||
|
|||||||
225
nodes.py
225
nodes.py
@ -6,10 +6,12 @@ import json
|
|||||||
import hashlib
|
import hashlib
|
||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||||
@ -28,6 +30,7 @@ import importlib
|
|||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
def before_node_execution():
|
def before_node_execution():
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
@ -145,9 +148,6 @@ class ConditioningSetMask:
|
|||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||||
@ -160,9 +160,6 @@ class VAEDecode:
|
|||||||
return (vae.decode(samples["samples"]), )
|
return (vae.decode(samples["samples"]), )
|
||||||
|
|
||||||
class VAEDecodeTiled:
|
class VAEDecodeTiled:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||||
@ -175,9 +172,6 @@ class VAEDecodeTiled:
|
|||||||
return (vae.decode_tiled(samples["samples"]), )
|
return (vae.decode_tiled(samples["samples"]), )
|
||||||
|
|
||||||
class VAEEncode:
|
class VAEEncode:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||||
@ -202,9 +196,6 @@ class VAEEncode:
|
|||||||
return ({"samples":t}, )
|
return ({"samples":t}, )
|
||||||
|
|
||||||
class VAEEncodeTiled:
|
class VAEEncodeTiled:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
|
||||||
@ -219,9 +210,6 @@ class VAEEncodeTiled:
|
|||||||
return ({"samples":t}, )
|
return ({"samples":t}, )
|
||||||
|
|
||||||
class VAEEncodeForInpaint:
|
class VAEEncodeForInpaint:
|
||||||
def __init__(self, device="cpu"):
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
|
||||||
@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
|
|||||||
|
|
||||||
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLatent:
|
||||||
|
def __init__(self):
|
||||||
|
self.output_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT", ),
|
||||||
|
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
|
||||||
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "save"
|
||||||
|
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
|
|
||||||
|
# support save metadata for latent sharing
|
||||||
|
prompt_info = ""
|
||||||
|
if prompt is not None:
|
||||||
|
prompt_info = json.dumps(prompt)
|
||||||
|
|
||||||
|
metadata = {"prompt": prompt_info}
|
||||||
|
if extra_pnginfo is not None:
|
||||||
|
for x in extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
|
file = f"{filename}_{counter:05}_.latent"
|
||||||
|
file = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
|
output = {}
|
||||||
|
output["latent_tensor"] = samples["samples"]
|
||||||
|
|
||||||
|
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||||
|
return {"required": {"latent": [sorted(files), ]}, }
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, latent):
|
||||||
|
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||||
|
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||||
|
samples = {"samples": latent["latent_tensor"].float()}
|
||||||
|
return (samples, )
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, latent):
|
||||||
|
image_path = folder_paths.get_annotated_filepath(latent)
|
||||||
|
m = hashlib.sha256()
|
||||||
|
with open(image_path, 'rb') as f:
|
||||||
|
m.update(f.read())
|
||||||
|
return m.digest().hex()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(s, latent):
|
||||||
|
if not folder_paths.exists_annotated_filepath(latent):
|
||||||
|
return "Invalid latent file: {}".format(latent)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class CheckpointLoader:
|
class CheckpointLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -296,7 +359,10 @@ class DiffusersLoader:
|
|||||||
paths = []
|
paths = []
|
||||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||||
if os.path.exists(search_path):
|
if os.path.exists(search_path):
|
||||||
paths += next(os.walk(search_path))[1]
|
for root, subdir, files in os.walk(search_path, followlinks=True):
|
||||||
|
if "model_index.json" in files:
|
||||||
|
paths.append(os.path.relpath(root, start=search_path))
|
||||||
|
|
||||||
return {"required": {"model_path": (paths,), }}
|
return {"required": {"model_path": (paths,), }}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
@ -306,9 +372,9 @@ class DiffusersLoader:
|
|||||||
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
|
||||||
for search_path in folder_paths.get_folder_paths("diffusers"):
|
for search_path in folder_paths.get_folder_paths("diffusers"):
|
||||||
if os.path.exists(search_path):
|
if os.path.exists(search_path):
|
||||||
paths = next(os.walk(search_path))[1]
|
path = os.path.join(search_path, model_path)
|
||||||
if model_path in paths:
|
if os.path.exists(path):
|
||||||
model_path = os.path.join(search_path, model_path)
|
model_path = path
|
||||||
break
|
break
|
||||||
|
|
||||||
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
@ -629,18 +695,57 @@ class LatentFromBatch:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return {"required": { "samples": ("LATENT",),
|
||||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||||
|
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "rotate"
|
FUNCTION = "frombatch"
|
||||||
|
|
||||||
CATEGORY = "latent"
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
def rotate(self, samples, batch_index):
|
def frombatch(self, samples, batch_index, length):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
s_in = samples["samples"]
|
s_in = samples["samples"]
|
||||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||||
s["samples"] = s_in[batch_index:batch_index + 1].clone()
|
length = min(s_in.shape[0] - batch_index, length)
|
||||||
s["batch_index"] = batch_index
|
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||||
|
if "noise_mask" in samples:
|
||||||
|
masks = samples["noise_mask"]
|
||||||
|
if masks.shape[0] == 1:
|
||||||
|
s["noise_mask"] = masks.clone()
|
||||||
|
else:
|
||||||
|
if masks.shape[0] < s_in.shape[0]:
|
||||||
|
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||||
|
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
|
||||||
|
if "batch_index" not in s:
|
||||||
|
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
|
||||||
|
else:
|
||||||
|
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
class RepeatLatentBatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples": ("LATENT",),
|
||||||
|
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "repeat"
|
||||||
|
|
||||||
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
|
def repeat(self, samples, amount):
|
||||||
|
s = samples.copy()
|
||||||
|
s_in = samples["samples"]
|
||||||
|
|
||||||
|
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||||
|
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||||
|
masks = samples["noise_mask"]
|
||||||
|
if masks.shape[0] < s_in.shape[0]:
|
||||||
|
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
|
||||||
|
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
|
||||||
|
if "batch_index" in s:
|
||||||
|
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
|
||||||
|
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class LatentUpscale:
|
class LatentUpscale:
|
||||||
@ -795,7 +900,7 @@ class SetLatentNoiseMask:
|
|||||||
|
|
||||||
def set_mask(self, samples, mask):
|
def set_mask(self, samples, mask):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
s["noise_mask"] = mask
|
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||||
@ -805,8 +910,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
if disable_noise:
|
if disable_noise:
|
||||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||||
else:
|
else:
|
||||||
skip = latent["batch_index"] if "batch_index" in latent else 0
|
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||||
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
|
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
|
||||||
|
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
@ -901,39 +1006,7 @@ class SaveImage:
|
|||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
def map_filename(filename):
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||||
prefix_len = len(os.path.basename(filename_prefix))
|
|
||||||
prefix = filename[:prefix_len + 1]
|
|
||||||
try:
|
|
||||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
|
||||||
except:
|
|
||||||
digits = 0
|
|
||||||
return (digits, prefix)
|
|
||||||
|
|
||||||
def compute_vars(input):
|
|
||||||
input = input.replace("%width%", str(images[0].shape[1]))
|
|
||||||
input = input.replace("%height%", str(images[0].shape[0]))
|
|
||||||
return input
|
|
||||||
|
|
||||||
filename_prefix = compute_vars(filename_prefix)
|
|
||||||
|
|
||||||
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
|
|
||||||
filename = os.path.basename(os.path.normpath(filename_prefix))
|
|
||||||
|
|
||||||
full_output_folder = os.path.join(self.output_dir, subfolder)
|
|
||||||
|
|
||||||
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
|
|
||||||
print("Saving image outside the output folder is not allowed.")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
|
||||||
except ValueError:
|
|
||||||
counter = 1
|
|
||||||
except FileNotFoundError:
|
|
||||||
os.makedirs(full_output_folder, exist_ok=True)
|
|
||||||
counter = 1
|
|
||||||
|
|
||||||
results = list()
|
results = list()
|
||||||
for image in images:
|
for image in images:
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.cpu().numpy()
|
||||||
@ -984,6 +1057,7 @@ class LoadImage:
|
|||||||
def load_image(self, image):
|
def load_image(self, image):
|
||||||
image_path = folder_paths.get_annotated_filepath(image)
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
|
i = ImageOps.exif_transpose(i)
|
||||||
image = i.convert("RGB")
|
image = i.convert("RGB")
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = torch.from_numpy(image)[None,]
|
image = torch.from_numpy(image)[None,]
|
||||||
@ -1027,6 +1101,7 @@ class LoadImageMask:
|
|||||||
def load_image(self, image, channel):
|
def load_image(self, image, channel):
|
||||||
image_path = folder_paths.get_annotated_filepath(image)
|
image_path = folder_paths.get_annotated_filepath(image)
|
||||||
i = Image.open(image_path)
|
i = Image.open(image_path)
|
||||||
|
i = ImageOps.exif_transpose(i)
|
||||||
if i.getbands() != ("R", "G", "B", "A"):
|
if i.getbands() != ("R", "G", "B", "A"):
|
||||||
i = i.convert("RGBA")
|
i = i.convert("RGBA")
|
||||||
mask = None
|
mask = None
|
||||||
@ -1170,6 +1245,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"EmptyLatentImage": EmptyLatentImage,
|
"EmptyLatentImage": EmptyLatentImage,
|
||||||
"LatentUpscale": LatentUpscale,
|
"LatentUpscale": LatentUpscale,
|
||||||
"LatentFromBatch": LatentFromBatch,
|
"LatentFromBatch": LatentFromBatch,
|
||||||
|
"RepeatLatentBatch": RepeatLatentBatch,
|
||||||
"SaveImage": SaveImage,
|
"SaveImage": SaveImage,
|
||||||
"PreviewImage": PreviewImage,
|
"PreviewImage": PreviewImage,
|
||||||
"LoadImage": LoadImage,
|
"LoadImage": LoadImage,
|
||||||
@ -1206,6 +1282,9 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
"CheckpointLoader": CheckpointLoader,
|
"CheckpointLoader": CheckpointLoader,
|
||||||
"DiffusersLoader": DiffusersLoader,
|
"DiffusersLoader": DiffusersLoader,
|
||||||
|
|
||||||
|
"LoadLatent": LoadLatent,
|
||||||
|
"SaveLatent": SaveLatent
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -1244,6 +1323,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"EmptyLatentImage": "Empty Latent Image",
|
"EmptyLatentImage": "Empty Latent Image",
|
||||||
"LatentUpscale": "Upscale Latent",
|
"LatentUpscale": "Upscale Latent",
|
||||||
"LatentComposite": "Latent Composite",
|
"LatentComposite": "Latent Composite",
|
||||||
|
"LatentFromBatch" : "Latent From Batch",
|
||||||
|
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||||
# Image
|
# Image
|
||||||
"SaveImage": "Save Image",
|
"SaveImage": "Save Image",
|
||||||
"PreviewImage": "Preview Image",
|
"PreviewImage": "Preview Image",
|
||||||
@ -1275,14 +1356,18 @@ def load_custom_node(module_path):
|
|||||||
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
|
||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||||
|
return False
|
||||||
|
|
||||||
def load_custom_nodes():
|
def load_custom_nodes():
|
||||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||||
|
node_import_times = []
|
||||||
for custom_node_path in node_paths:
|
for custom_node_path in node_paths:
|
||||||
possible_modules = os.listdir(custom_node_path)
|
possible_modules = os.listdir(custom_node_path)
|
||||||
if "__pycache__" in possible_modules:
|
if "__pycache__" in possible_modules:
|
||||||
@ -1291,11 +1376,25 @@ def load_custom_nodes():
|
|||||||
for possible_module in possible_modules:
|
for possible_module in possible_modules:
|
||||||
module_path = os.path.join(custom_node_path, possible_module)
|
module_path = os.path.join(custom_node_path, possible_module)
|
||||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||||
load_custom_node(module_path)
|
if module_path.endswith(".disabled"): continue
|
||||||
|
time_before = time.perf_counter()
|
||||||
|
success = load_custom_node(module_path)
|
||||||
|
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||||
|
|
||||||
|
if len(node_import_times) > 0:
|
||||||
|
print("\nImport times for custom nodes:")
|
||||||
|
for n in sorted(node_import_times):
|
||||||
|
if n[2]:
|
||||||
|
import_message = ""
|
||||||
|
else:
|
||||||
|
import_message = " (IMPORT FAILED)"
|
||||||
|
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||||
|
print()
|
||||||
|
|
||||||
def init_custom_nodes():
|
def init_custom_nodes():
|
||||||
load_custom_nodes()
|
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
|
load_custom_nodes()
|
||||||
|
|||||||
@ -175,6 +175,8 @@
|
|||||||
"import threading\n",
|
"import threading\n",
|
||||||
"import time\n",
|
"import time\n",
|
||||||
"import socket\n",
|
"import socket\n",
|
||||||
|
"import urllib.request\n",
|
||||||
|
"\n",
|
||||||
"def iframe_thread(port):\n",
|
"def iframe_thread(port):\n",
|
||||||
" while True:\n",
|
" while True:\n",
|
||||||
" time.sleep(0.5)\n",
|
" time.sleep(0.5)\n",
|
||||||
@ -183,7 +185,9 @@
|
|||||||
" if result == 0:\n",
|
" if result == 0:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
" sock.close()\n",
|
" sock.close()\n",
|
||||||
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
|
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
|
||||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||||
" for line in p.stdout:\n",
|
" for line in p.stdout:\n",
|
||||||
" print(line.decode(), end='')\n",
|
" print(line.decode(), end='')\n",
|
||||||
|
|||||||
74
server.py
74
server.py
@ -115,21 +115,23 @@ class PromptServer():
|
|||||||
|
|
||||||
def get_dir_by_type(dir_type):
|
def get_dir_by_type(dir_type):
|
||||||
if dir_type is None:
|
if dir_type is None:
|
||||||
type_dir = folder_paths.get_input_directory()
|
dir_type = "input"
|
||||||
elif dir_type == "input":
|
|
||||||
|
if dir_type == "input":
|
||||||
type_dir = folder_paths.get_input_directory()
|
type_dir = folder_paths.get_input_directory()
|
||||||
elif dir_type == "temp":
|
elif dir_type == "temp":
|
||||||
type_dir = folder_paths.get_temp_directory()
|
type_dir = folder_paths.get_temp_directory()
|
||||||
elif dir_type == "output":
|
elif dir_type == "output":
|
||||||
type_dir = folder_paths.get_output_directory()
|
type_dir = folder_paths.get_output_directory()
|
||||||
|
|
||||||
return type_dir
|
return type_dir, dir_type
|
||||||
|
|
||||||
def image_upload(post, image_save_function=None):
|
def image_upload(post, image_save_function=None):
|
||||||
image = post.get("image")
|
image = post.get("image")
|
||||||
|
overwrite = post.get("overwrite")
|
||||||
|
|
||||||
image_upload_type = post.get("type")
|
image_upload_type = post.get("type")
|
||||||
upload_dir = get_dir_by_type(image_upload_type)
|
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
|
||||||
|
|
||||||
if image and image.file:
|
if image and image.file:
|
||||||
filename = image.filename
|
filename = image.filename
|
||||||
@ -148,10 +150,14 @@ class PromptServer():
|
|||||||
split = os.path.splitext(filename)
|
split = os.path.splitext(filename)
|
||||||
filepath = os.path.join(full_output_folder, filename)
|
filepath = os.path.join(full_output_folder, filename)
|
||||||
|
|
||||||
i = 1
|
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
|
||||||
while os.path.exists(filepath):
|
pass
|
||||||
filename = f"{split[0]} ({i}){split[1]}"
|
else:
|
||||||
i += 1
|
i = 1
|
||||||
|
while os.path.exists(filepath):
|
||||||
|
filename = f"{split[0]} ({i}){split[1]}"
|
||||||
|
filepath = os.path.join(full_output_folder, filename)
|
||||||
|
i += 1
|
||||||
|
|
||||||
if image_save_function is not None:
|
if image_save_function is not None:
|
||||||
image_save_function(image, post, filepath)
|
image_save_function(image, post, filepath)
|
||||||
@ -255,22 +261,34 @@ class PromptServer():
|
|||||||
async def get_prompt(request):
|
async def get_prompt(request):
|
||||||
return web.json_response(self.get_queue_info())
|
return web.json_response(self.get_queue_info())
|
||||||
|
|
||||||
|
def node_info(node_class):
|
||||||
|
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
info = {}
|
||||||
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
|
info['output'] = obj_class.RETURN_TYPES
|
||||||
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||||
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
|
info['name'] = node_class
|
||||||
|
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||||
|
info['description'] = ''
|
||||||
|
info['category'] = 'sd'
|
||||||
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
|
info['category'] = obj_class.CATEGORY
|
||||||
|
return info
|
||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
|
out[x] = node_info(x)
|
||||||
info = {}
|
return web.json_response(out)
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
@routes.get("/object_info/{node_class}")
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
async def get_object_info_node(request):
|
||||||
info['name'] = x
|
node_class = request.match_info.get("node_class", None)
|
||||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
|
out = {}
|
||||||
info['description'] = ''
|
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||||
info['category'] = 'sd'
|
out[node_class] = node_info(node_class)
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
|
||||||
info['category'] = obj_class.CATEGORY
|
|
||||||
out[x] = info
|
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@routes.get("/history")
|
@routes.get("/history")
|
||||||
@ -312,13 +330,15 @@ class PromptServer():
|
|||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
|
prompt_id = str(uuid.uuid4())
|
||||||
|
outputs_to_execute = valid[2]
|
||||||
|
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||||
|
return web.json_response({"prompt_id": prompt_id})
|
||||||
else:
|
else:
|
||||||
resp_code = 400
|
|
||||||
out_string = valid[1]
|
|
||||||
print("invalid prompt:", valid[1])
|
print("invalid prompt:", valid[1])
|
||||||
|
return web.json_response({"error": valid[1]}, status=400)
|
||||||
return web.Response(body=out_string, status=resp_code)
|
else:
|
||||||
|
return web.json_response({"error": "no prompt"}, status=400)
|
||||||
|
|
||||||
@routes.post("/queue")
|
@routes.post("/queue")
|
||||||
async def post_queue(request):
|
async def post_queue(request):
|
||||||
@ -329,7 +349,7 @@ class PromptServer():
|
|||||||
if "delete" in json_data:
|
if "delete" in json_data:
|
||||||
to_delete = json_data['delete']
|
to_delete = json_data['delete']
|
||||||
for id_to_delete in to_delete:
|
for id_to_delete in to_delete:
|
||||||
delete_func = lambda a: a[1] == int(id_to_delete)
|
delete_func = lambda a: a[1] == id_to_delete
|
||||||
self.prompt_queue.delete_queue_item(delete_func)
|
self.prompt_queue.delete_queue_item(delete_func)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
@ -355,7 +375,7 @@ class PromptServer():
|
|||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.app.add_routes(self.routes)
|
self.app.add_routes(self.routes)
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root),
|
web.static('/', self.web_root, follow_symlinks=True),
|
||||||
])
|
])
|
||||||
|
|
||||||
def get_queue_info(self):
|
def get_queue_info(self):
|
||||||
|
|||||||
@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) {
|
|||||||
|
|
||||||
class MaskEditorDialog extends ComfyDialog {
|
class MaskEditorDialog extends ComfyDialog {
|
||||||
static instance = null;
|
static instance = null;
|
||||||
|
|
||||||
|
static getInstance() {
|
||||||
|
if(!MaskEditorDialog.instance) {
|
||||||
|
MaskEditorDialog.instance = new MaskEditorDialog(app);
|
||||||
|
}
|
||||||
|
|
||||||
|
return MaskEditorDialog.instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
is_layout_created = false;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
super();
|
super();
|
||||||
this.element = $el("div.comfy-modal", { parent: document.body },
|
this.element = $el("div.comfy-modal", { parent: document.body },
|
||||||
[ $el("div.comfy-modal-content",
|
[ $el("div.comfy-modal-content",
|
||||||
[...this.createButtons()]),
|
[...this.createButtons()]),
|
||||||
]);
|
]);
|
||||||
MaskEditorDialog.instance = this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
createButtons() {
|
createButtons() {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
clearMask(self) {
|
|
||||||
}
|
|
||||||
|
|
||||||
createButton(name, callback) {
|
createButton(name, callback) {
|
||||||
var button = document.createElement("button");
|
var button = document.createElement("button");
|
||||||
button.innerText = name;
|
button.innerText = name;
|
||||||
button.addEventListener("click", callback);
|
button.addEventListener("click", callback);
|
||||||
return button;
|
return button;
|
||||||
}
|
}
|
||||||
|
|
||||||
createLeftButton(name, callback) {
|
createLeftButton(name, callback) {
|
||||||
var button = this.createButton(name, callback);
|
var button = this.createButton(name, callback);
|
||||||
button.style.cssFloat = "left";
|
button.style.cssFloat = "left";
|
||||||
button.style.marginRight = "4px";
|
button.style.marginRight = "4px";
|
||||||
return button;
|
return button;
|
||||||
}
|
}
|
||||||
|
|
||||||
createRightButton(name, callback) {
|
createRightButton(name, callback) {
|
||||||
var button = this.createButton(name, callback);
|
var button = this.createButton(name, callback);
|
||||||
button.style.cssFloat = "right";
|
button.style.cssFloat = "right";
|
||||||
button.style.marginLeft = "4px";
|
button.style.marginLeft = "4px";
|
||||||
return button;
|
return button;
|
||||||
}
|
}
|
||||||
|
|
||||||
createLeftSlider(self, name, callback) {
|
createLeftSlider(self, name, callback) {
|
||||||
const divElement = document.createElement('div');
|
const divElement = document.createElement('div');
|
||||||
divElement.id = "maskeditor-slider";
|
divElement.id = "maskeditor-slider";
|
||||||
@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
brush.style.MozBorderRadius = "50%";
|
brush.style.MozBorderRadius = "50%";
|
||||||
brush.style.WebkitBorderRadius = "50%";
|
brush.style.WebkitBorderRadius = "50%";
|
||||||
brush.style.position = "absolute";
|
brush.style.position = "absolute";
|
||||||
brush.style.zIndex = 100;
|
brush.style.zIndex = 8889;
|
||||||
brush.style.pointerEvents = "none";
|
brush.style.pointerEvents = "none";
|
||||||
this.brush = brush;
|
this.brush = brush;
|
||||||
this.element.appendChild(imgCanvas);
|
this.element.appendChild(imgCanvas);
|
||||||
@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||||
self.close();
|
self.close();
|
||||||
});
|
});
|
||||||
var saveButton = this.createRightButton("Save", () => {
|
|
||||||
|
this.saveButton = this.createRightButton("Save", () => {
|
||||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||||
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
|
||||||
self.save();
|
self.save();
|
||||||
@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
this.element.appendChild(bottom_panel);
|
this.element.appendChild(bottom_panel);
|
||||||
|
|
||||||
bottom_panel.appendChild(clearButton);
|
bottom_panel.appendChild(clearButton);
|
||||||
bottom_panel.appendChild(saveButton);
|
bottom_panel.appendChild(this.saveButton);
|
||||||
bottom_panel.appendChild(cancelButton);
|
bottom_panel.appendChild(cancelButton);
|
||||||
bottom_panel.appendChild(brush_size_slider);
|
bottom_panel.appendChild(brush_size_slider);
|
||||||
|
|
||||||
this.element.style.display = "block";
|
|
||||||
imgCanvas.style.position = "relative";
|
imgCanvas.style.position = "relative";
|
||||||
imgCanvas.style.top = "200";
|
imgCanvas.style.top = "200";
|
||||||
imgCanvas.style.left = "0";
|
imgCanvas.style.left = "0";
|
||||||
@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
}
|
}
|
||||||
|
|
||||||
show() {
|
show() {
|
||||||
// layout
|
if(!this.is_layout_created) {
|
||||||
const imgCanvas = document.createElement('canvas');
|
// layout
|
||||||
const maskCanvas = document.createElement('canvas');
|
const imgCanvas = document.createElement('canvas');
|
||||||
const backupCanvas = document.createElement('canvas');
|
const maskCanvas = document.createElement('canvas');
|
||||||
|
const backupCanvas = document.createElement('canvas');
|
||||||
|
|
||||||
imgCanvas.id = "imageCanvas";
|
imgCanvas.id = "imageCanvas";
|
||||||
maskCanvas.id = "maskCanvas";
|
maskCanvas.id = "maskCanvas";
|
||||||
backupCanvas.id = "backupCanvas";
|
backupCanvas.id = "backupCanvas";
|
||||||
|
|
||||||
this.setlayout(imgCanvas, maskCanvas);
|
this.setlayout(imgCanvas, maskCanvas);
|
||||||
|
|
||||||
// prepare content
|
// prepare content
|
||||||
this.maskCanvas = maskCanvas;
|
this.imgCanvas = imgCanvas;
|
||||||
this.backupCanvas = backupCanvas;
|
this.maskCanvas = maskCanvas;
|
||||||
this.maskCtx = maskCanvas.getContext('2d');
|
this.backupCanvas = backupCanvas;
|
||||||
this.backupCtx = backupCanvas.getContext('2d');
|
this.maskCtx = maskCanvas.getContext('2d');
|
||||||
|
this.backupCtx = backupCanvas.getContext('2d');
|
||||||
|
|
||||||
this.setImages(imgCanvas, backupCanvas);
|
this.setEventHandler(maskCanvas);
|
||||||
this.setEventHandler(maskCanvas);
|
|
||||||
|
this.is_layout_created = true;
|
||||||
|
|
||||||
|
// replacement of onClose hook since close is not real close
|
||||||
|
const self = this;
|
||||||
|
const observer = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutation) {
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||||
|
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||||
|
ComfyApp.onClipspaceEditorClosed();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.last_display_style = self.element.style.display;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const config = { attributes: true };
|
||||||
|
observer.observe(this.element, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace_return_node) {
|
||||||
|
this.saveButton.innerText = "Save to node";
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
this.saveButton.innerText = "Save";
|
||||||
|
}
|
||||||
|
this.saveButton.disabled = false;
|
||||||
|
|
||||||
|
this.element.style.display = "block";
|
||||||
|
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||||
|
}
|
||||||
|
|
||||||
|
isOpened() {
|
||||||
|
return this.element.style.display == "block";
|
||||||
}
|
}
|
||||||
|
|
||||||
setImages(imgCanvas, backupCanvas) {
|
setImages(imgCanvas, backupCanvas) {
|
||||||
@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
const maskCtx = this.maskCtx;
|
const maskCtx = this.maskCtx;
|
||||||
const maskCanvas = this.maskCanvas;
|
const maskCanvas = this.maskCanvas;
|
||||||
|
|
||||||
|
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||||
|
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||||
|
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||||
|
|
||||||
// image load
|
// image load
|
||||||
const orig_image = new Image();
|
const orig_image = new Image();
|
||||||
window.addEventListener("resize", () => {
|
window.addEventListener("resize", () => {
|
||||||
@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
rgb_url.searchParams.set('channel', 'rgb');
|
rgb_url.searchParams.set('channel', 'rgb');
|
||||||
orig_image.src = rgb_url;
|
orig_image.src = rgb_url;
|
||||||
this.image = orig_image;
|
this.image = orig_image;
|
||||||
}g
|
}
|
||||||
|
|
||||||
|
|
||||||
setEventHandler(maskCanvas) {
|
setEventHandler(maskCanvas) {
|
||||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||||
@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||||
} else if (event.key === '[') {
|
} else if (event.key === '[') {
|
||||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||||
|
} else if(event.key === 'Enter') {
|
||||||
|
self.save();
|
||||||
}
|
}
|
||||||
|
|
||||||
self.updateBrushPreview(self);
|
self.updateBrushPreview(self);
|
||||||
@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
save() {
|
async save() {
|
||||||
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||||
|
|
||||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||||
@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
formData.append('type', "input");
|
formData.append('type', "input");
|
||||||
formData.append('subfolder', "clipspace");
|
formData.append('subfolder', "clipspace");
|
||||||
|
|
||||||
uploadMask(item, formData);
|
this.saveButton.innerText = "Saving...";
|
||||||
|
this.saveButton.disabled = true;
|
||||||
|
await uploadMask(item, formData);
|
||||||
|
ComfyApp.onClipspaceEditorSave();
|
||||||
this.close();
|
this.close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "Comfy.MaskEditor",
|
name: "Comfy.MaskEditor",
|
||||||
init(app) {
|
init(app) {
|
||||||
const callback =
|
ComfyApp.open_maskeditor =
|
||||||
function () {
|
function () {
|
||||||
let dlg = new MaskEditorDialog(app);
|
const dlg = MaskEditorDialog.getInstance();
|
||||||
dlg.show();
|
if(!dlg.isOpened()) {
|
||||||
|
dlg.show();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
|
||||||
ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback);
|
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -300,7 +300,7 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (widget.type === "number") {
|
if (widget.type === "number" || widget.type === "combo") {
|
||||||
addValueControlWidget(this, widget, "fixed");
|
addValueControlWidget(this, widget, "fixed");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
|
|
||||||
//when clicked on top of a node
|
//when clicked on top of a node
|
||||||
//and it is not interactive
|
//and it is not interactive
|
||||||
if (node && this.allow_interaction && !skip_action && !this.read_only) {
|
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
|
||||||
if (!this.live_mode && !node.flags.pinned) {
|
if (!this.live_mode && !node.flags.pinned) {
|
||||||
this.bringToFront(node);
|
this.bringToFront(node);
|
||||||
} //if it wasn't selected?
|
} //if it wasn't selected?
|
||||||
|
|
||||||
//not dragging mouse to connect two slots
|
//not dragging mouse to connect two slots
|
||||||
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
|
||||||
//Search for corner for resize
|
//Search for corner for resize
|
||||||
if ( !skip_action &&
|
if ( !skip_action &&
|
||||||
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
|
||||||
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
}
|
}
|
||||||
|
|
||||||
//double clicking
|
//double clicking
|
||||||
if (is_double_click && this.selected_nodes[node.id]) {
|
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
|
||||||
//double click node
|
//double click node
|
||||||
if (node.onDblClick) {
|
if (node.onDblClick) {
|
||||||
node.onDblClick( e, pos, this );
|
node.onDblClick( e, pos, this );
|
||||||
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//get node over
|
||||||
|
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
||||||
|
|
||||||
if (this.dragging_rectangle)
|
if (this.dragging_rectangle)
|
||||||
{
|
{
|
||||||
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
|
||||||
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
this.ds.offset[1] += delta[1] / this.ds.scale;
|
this.ds.offset[1] += delta[1] / this.ds.scale;
|
||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
this.dirty_bgcanvas = true;
|
this.dirty_bgcanvas = true;
|
||||||
} else if (this.allow_interaction && !this.read_only) {
|
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
|
||||||
if (this.connecting_node) {
|
if (this.connecting_node) {
|
||||||
this.dirty_canvas = true;
|
this.dirty_canvas = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//get node over
|
|
||||||
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
|
|
||||||
|
|
||||||
//remove mouseover flag
|
//remove mouseover flag
|
||||||
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
|
||||||
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
|
||||||
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
if (show_text) {
|
if (show_text) {
|
||||||
ctx.textAlign = "center";
|
ctx.textAlign = "center";
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
|
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "toggle":
|
case "toggle":
|
||||||
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.fill();
|
ctx.fill();
|
||||||
if (show_text) {
|
if (show_text) {
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
if (w.name != null) {
|
const label = w.label || w.name;
|
||||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
if (label != null) {
|
||||||
|
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||||
}
|
}
|
||||||
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
ctx.fillStyle = w.value ? text_color : secondary_text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.textAlign = "center";
|
ctx.textAlign = "center";
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.fillText(
|
ctx.fillText(
|
||||||
w.name + " " + Number(w.value).toFixed(3),
|
w.label || w.name + " " + Number(w.value).toFixed(3),
|
||||||
widget_width * 0.5,
|
widget_width * 0.5,
|
||||||
y + H * 0.7
|
y + H * 0.7
|
||||||
);
|
);
|
||||||
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
ctx.fill();
|
ctx.fill();
|
||||||
}
|
}
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
|
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
if (w.type == "number") {
|
if (w.type == "number") {
|
||||||
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
|
|
||||||
//ctx.stroke();
|
//ctx.stroke();
|
||||||
ctx.fillStyle = secondary_text_color;
|
ctx.fillStyle = secondary_text_color;
|
||||||
if (w.name != null) {
|
const label = w.label || w.name;
|
||||||
ctx.fillText(w.name, margin * 2, y + H * 0.7);
|
if (label != null) {
|
||||||
|
ctx.fillText(label, margin * 2, y + H * 0.7);
|
||||||
}
|
}
|
||||||
ctx.fillStyle = text_color;
|
ctx.fillStyle = text_color;
|
||||||
ctx.textAlign = "right";
|
ctx.textAlign = "right";
|
||||||
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
event,
|
event,
|
||||||
active_widget
|
active_widget
|
||||||
) {
|
) {
|
||||||
if (!node.widgets || !node.widgets.length) {
|
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
canvas.graph.add(group);
|
canvas.graph.add(group);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determines the furthest nodes in each direction
|
||||||
|
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
|
||||||
|
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||||
|
*/
|
||||||
|
LGraphCanvas.getBoundaryNodes = function(nodes) {
|
||||||
|
let top = null;
|
||||||
|
let right = null;
|
||||||
|
let bottom = null;
|
||||||
|
let left = null;
|
||||||
|
for (const nID in nodes) {
|
||||||
|
const node = nodes[nID];
|
||||||
|
const [x, y] = node.pos;
|
||||||
|
const [width, height] = node.size;
|
||||||
|
|
||||||
|
if (top === null || y < top.pos[1]) {
|
||||||
|
top = node;
|
||||||
|
}
|
||||||
|
if (right === null || x + width > right.pos[0] + right.size[0]) {
|
||||||
|
right = node;
|
||||||
|
}
|
||||||
|
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
|
||||||
|
bottom = node;
|
||||||
|
}
|
||||||
|
if (left === null || x < left.pos[0]) {
|
||||||
|
left = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"top": top,
|
||||||
|
"right": right,
|
||||||
|
"bottom": bottom,
|
||||||
|
"left": left
|
||||||
|
};
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Determines the furthest nodes in each direction for the currently selected nodes
|
||||||
|
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
|
||||||
|
*/
|
||||||
|
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
|
||||||
|
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param {LGraphNode[]} nodes a list of nodes
|
||||||
|
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
|
||||||
|
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
|
||||||
|
*/
|
||||||
|
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
|
||||||
|
if (!nodes) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const canvas = LGraphCanvas.active_canvas;
|
||||||
|
let boundaryNodes = []
|
||||||
|
if (align_to === undefined) {
|
||||||
|
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
|
||||||
|
} else {
|
||||||
|
boundaryNodes = {
|
||||||
|
"top": align_to,
|
||||||
|
"right": align_to,
|
||||||
|
"bottom": align_to,
|
||||||
|
"left": align_to
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
|
||||||
|
switch (direction) {
|
||||||
|
case "right":
|
||||||
|
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
|
||||||
|
break;
|
||||||
|
case "left":
|
||||||
|
node.pos[0] = boundaryNodes["left"].pos[0];
|
||||||
|
break;
|
||||||
|
case "top":
|
||||||
|
node.pos[1] = boundaryNodes["top"].pos[1];
|
||||||
|
break;
|
||||||
|
case "bottom":
|
||||||
|
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
canvas.dirty_canvas = true;
|
||||||
|
canvas.dirty_bgcanvas = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
|
||||||
|
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||||
|
event: event,
|
||||||
|
callback: inner_clicked,
|
||||||
|
parentMenu: prev_menu,
|
||||||
|
});
|
||||||
|
|
||||||
|
function inner_clicked(value) {
|
||||||
|
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
|
||||||
|
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
|
||||||
|
event: event,
|
||||||
|
callback: inner_clicked,
|
||||||
|
parentMenu: prev_menu,
|
||||||
|
});
|
||||||
|
|
||||||
|
function inner_clicked(value) {
|
||||||
|
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
|
||||||
|
|
||||||
var canvas = LGraphCanvas.active_canvas;
|
var canvas = LGraphCanvas.active_canvas;
|
||||||
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
|
if (Object.keys(this.selected_nodes).length > 1) {
|
||||||
|
options.push({
|
||||||
|
content: "Align",
|
||||||
|
has_submenu: true,
|
||||||
|
callback: LGraphCanvas.onGroupAlign,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if (this._graph_stack && this._graph_stack.length > 0) {
|
if (this._graph_stack && this._graph_stack.length > 0) {
|
||||||
options.push(null, {
|
options.push(null, {
|
||||||
content: "Close subgraph",
|
content: "Close subgraph",
|
||||||
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
callback: LGraphCanvas.onMenuNodeToSubgraph
|
callback: LGraphCanvas.onMenuNodeToSubgraph
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (Object.keys(this.selected_nodes).length > 1) {
|
||||||
|
options.push({
|
||||||
|
content: "Align Selected To",
|
||||||
|
has_submenu: true,
|
||||||
|
callback: LGraphCanvas.onNodeAlign,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
options.push(null, {
|
options.push(null, {
|
||||||
content: "Remove",
|
content: "Remove",
|
||||||
disabled: !(node.removable !== false && !node.block_delete ),
|
disabled: !(node.removable !== false && !node.block_delete ),
|
||||||
|
|||||||
@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
|
|||||||
|
|
||||||
if (res.status !== 200) {
|
if (res.status !== 200) {
|
||||||
throw {
|
throw {
|
||||||
response: await res.text(),
|
response: await res.json(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
|
|||||||
import { ComfyUI, $el } from "./ui.js";
|
import { ComfyUI, $el } from "./ui.js";
|
||||||
import { api } from "./api.js";
|
import { api } from "./api.js";
|
||||||
import { defaultGraph } from "./defaultGraph.js";
|
import { defaultGraph } from "./defaultGraph.js";
|
||||||
import { getPngMetadata, importA1111 } from "./pnginfo.js";
|
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
|
||||||
@ -26,6 +26,8 @@ export class ComfyApp {
|
|||||||
*/
|
*/
|
||||||
static clipspace = null;
|
static clipspace = null;
|
||||||
static clipspace_invalidate_handler = null;
|
static clipspace_invalidate_handler = null;
|
||||||
|
static open_maskeditor = null;
|
||||||
|
static clipspace_return_node = null;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
this.ui = new ComfyUI(this);
|
this.ui = new ComfyUI(this);
|
||||||
@ -49,6 +51,114 @@ export class ComfyApp {
|
|||||||
this.shiftDown = false;
|
this.shiftDown = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static isImageNode(node) {
|
||||||
|
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static onClipspaceEditorSave() {
|
||||||
|
if(ComfyApp.clipspace_return_node) {
|
||||||
|
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static onClipspaceEditorClosed() {
|
||||||
|
ComfyApp.clipspace_return_node = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
static copyToClipspace(node) {
|
||||||
|
var widgets = null;
|
||||||
|
if(node.widgets) {
|
||||||
|
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||||
|
}
|
||||||
|
|
||||||
|
var imgs = undefined;
|
||||||
|
var orig_imgs = undefined;
|
||||||
|
if(node.imgs != undefined) {
|
||||||
|
imgs = [];
|
||||||
|
orig_imgs = [];
|
||||||
|
|
||||||
|
for (let i = 0; i < node.imgs.length; i++) {
|
||||||
|
imgs[i] = new Image();
|
||||||
|
imgs[i].src = node.imgs[i].src;
|
||||||
|
orig_imgs[i] = imgs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var selectedIndex = 0;
|
||||||
|
if(node.imageIndex) {
|
||||||
|
selectedIndex = node.imageIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
ComfyApp.clipspace = {
|
||||||
|
'widgets': widgets,
|
||||||
|
'imgs': imgs,
|
||||||
|
'original_imgs': orig_imgs,
|
||||||
|
'images': node.images,
|
||||||
|
'selectedIndex': selectedIndex,
|
||||||
|
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
||||||
|
};
|
||||||
|
|
||||||
|
ComfyApp.clipspace_return_node = null;
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace_invalidate_handler) {
|
||||||
|
ComfyApp.clipspace_invalidate_handler();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static pasteFromClipspace(node) {
|
||||||
|
if(ComfyApp.clipspace) {
|
||||||
|
// image paste
|
||||||
|
if(ComfyApp.clipspace.imgs && node.imgs) {
|
||||||
|
if(node.images && ComfyApp.clipspace.images) {
|
||||||
|
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||||
|
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
||||||
|
}
|
||||||
|
else
|
||||||
|
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(ComfyApp.clipspace.imgs) {
|
||||||
|
// deep-copy to cut link with clipspace
|
||||||
|
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
||||||
|
const img = new Image();
|
||||||
|
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
||||||
|
node.imgs = [img];
|
||||||
|
node.imageIndex = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const imgs = [];
|
||||||
|
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
||||||
|
imgs[i] = new Image();
|
||||||
|
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
||||||
|
node.imgs = imgs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.widgets) {
|
||||||
|
if(ComfyApp.clipspace.images) {
|
||||||
|
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||||
|
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||||
|
if(index >= 0) {
|
||||||
|
node.widgets[index].value = clip_image;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(ComfyApp.clipspace.widgets) {
|
||||||
|
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||||
|
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||||
|
if (prop && prop.type != 'button') {
|
||||||
|
prop.value = value;
|
||||||
|
prop.callback(value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.graph.setDirtyCanvas(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoke an extension callback
|
* Invoke an extension callback
|
||||||
* @param {keyof ComfyExtension} method The extension callback to execute
|
* @param {keyof ComfyExtension} method The extension callback to execute
|
||||||
@ -138,102 +248,30 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options.push(
|
// prevent conflict of clipspace content
|
||||||
{
|
if(!ComfyApp.clipspace_return_node) {
|
||||||
content: "Copy (Clipspace)",
|
options.push({
|
||||||
callback: (obj) => {
|
content: "Copy (Clipspace)",
|
||||||
var widgets = null;
|
callback: (obj) => { ComfyApp.copyToClipspace(this); }
|
||||||
if(this.widgets) {
|
});
|
||||||
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
|
||||||
}
|
|
||||||
|
|
||||||
var imgs = undefined;
|
if(ComfyApp.clipspace != null) {
|
||||||
var orig_imgs = undefined;
|
options.push({
|
||||||
if(this.imgs != undefined) {
|
content: "Paste (Clipspace)",
|
||||||
imgs = [];
|
callback: () => { ComfyApp.pasteFromClipspace(this); }
|
||||||
orig_imgs = [];
|
});
|
||||||
|
}
|
||||||
|
|
||||||
for (let i = 0; i < this.imgs.length; i++) {
|
if(ComfyApp.isImageNode(this)) {
|
||||||
imgs[i] = new Image();
|
options.push({
|
||||||
imgs[i].src = this.imgs[i].src;
|
content: "Open in MaskEditor",
|
||||||
orig_imgs[i] = imgs[i];
|
callback: (obj) => {
|
||||||
|
ComfyApp.copyToClipspace(this);
|
||||||
|
ComfyApp.clipspace_return_node = this;
|
||||||
|
ComfyApp.open_maskeditor();
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
|
}
|
||||||
ComfyApp.clipspace = {
|
|
||||||
'widgets': widgets,
|
|
||||||
'imgs': imgs,
|
|
||||||
'original_imgs': orig_imgs,
|
|
||||||
'images': this.images,
|
|
||||||
'selectedIndex': 0,
|
|
||||||
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
|
|
||||||
};
|
|
||||||
|
|
||||||
if(ComfyApp.clipspace_invalidate_handler) {
|
|
||||||
ComfyApp.clipspace_invalidate_handler();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if(ComfyApp.clipspace != null) {
|
|
||||||
options.push(
|
|
||||||
{
|
|
||||||
content: "Paste (Clipspace)",
|
|
||||||
callback: () => {
|
|
||||||
if(ComfyApp.clipspace) {
|
|
||||||
// image paste
|
|
||||||
if(ComfyApp.clipspace.imgs && this.imgs) {
|
|
||||||
if(this.images && ComfyApp.clipspace.images) {
|
|
||||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
|
||||||
app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
|
|
||||||
|
|
||||||
}
|
|
||||||
else
|
|
||||||
app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(ComfyApp.clipspace.imgs) {
|
|
||||||
// deep-copy to cut link with clipspace
|
|
||||||
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
|
|
||||||
const img = new Image();
|
|
||||||
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
|
|
||||||
this.imgs = [img];
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
const imgs = [];
|
|
||||||
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
|
|
||||||
imgs[i] = new Image();
|
|
||||||
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
|
|
||||||
this.imgs = imgs;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if(this.widgets) {
|
|
||||||
if(ComfyApp.clipspace.images) {
|
|
||||||
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
|
||||||
const index = this.widgets.findIndex(obj => obj.name === 'image');
|
|
||||||
if(index >= 0) {
|
|
||||||
this.widgets[index].value = clip_image;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if(ComfyApp.clipspace.widgets) {
|
|
||||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
|
||||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
|
||||||
if (prop && prop.type != 'button') {
|
|
||||||
prop.value = value;
|
|
||||||
prop.callback(value);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
app.graph.setDirtyCanvas(true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -864,7 +902,9 @@ export class ComfyApp {
|
|||||||
await this.#loadExtensions();
|
await this.#loadExtensions();
|
||||||
|
|
||||||
// Create and mount the LiteGraph in the DOM
|
// Create and mount the LiteGraph in the DOM
|
||||||
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
|
const mainCanvas = document.createElement("canvas")
|
||||||
|
mainCanvas.style.touchAction = "none"
|
||||||
|
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
|
||||||
canvasEl.tabIndex = "1";
|
canvasEl.tabIndex = "1";
|
||||||
document.body.prepend(canvasEl);
|
document.body.prepend(canvasEl);
|
||||||
|
|
||||||
@ -976,7 +1016,8 @@ export class ComfyApp {
|
|||||||
for (const o in nodeData["output"]) {
|
for (const o in nodeData["output"]) {
|
||||||
const output = nodeData["output"][o];
|
const output = nodeData["output"][o];
|
||||||
const outputName = nodeData["output_name"][o] || output;
|
const outputName = nodeData["output_name"][o] || output;
|
||||||
this.addOutput(outputName, output);
|
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||||
|
this.addOutput(outputName, output, { shape: outputShape });
|
||||||
}
|
}
|
||||||
|
|
||||||
const s = this.computeSize();
|
const s = this.computeSize();
|
||||||
@ -1237,7 +1278,7 @@ export class ComfyApp {
|
|||||||
try {
|
try {
|
||||||
await api.queuePrompt(number, p);
|
await api.queuePrompt(number, p);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.ui.dialog.show(error.response || error.toString());
|
this.ui.dialog.show(error.response.error || error.toString());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1283,6 +1324,11 @@ export class ComfyApp {
|
|||||||
this.loadGraphData(JSON.parse(reader.result));
|
this.loadGraphData(JSON.parse(reader.result));
|
||||||
};
|
};
|
||||||
reader.readAsText(file);
|
reader.readAsText(file);
|
||||||
|
} else if (file.name?.endsWith(".latent")) {
|
||||||
|
const info = await getLatentMetadata(file);
|
||||||
|
if (info.workflow) {
|
||||||
|
this.loadGraphData(JSON.parse(info.workflow));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -47,6 +47,22 @@ export function getPngMetadata(file) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getLatentMetadata(file) {
|
||||||
|
return new Promise((r) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = (event) => {
|
||||||
|
const safetensorsData = new Uint8Array(event.target.result);
|
||||||
|
const dataView = new DataView(safetensorsData.buffer);
|
||||||
|
let header_size = dataView.getUint32(0, true);
|
||||||
|
let offset = 8;
|
||||||
|
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
|
||||||
|
r(header.__metadata__);
|
||||||
|
};
|
||||||
|
|
||||||
|
reader.readAsArrayBuffer(file);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
export async function importA1111(graph, parameters) {
|
export async function importA1111(graph, parameters) {
|
||||||
const p = parameters.lastIndexOf("\nSteps:");
|
const p = parameters.lastIndexOf("\nSteps:");
|
||||||
if (p > -1) {
|
if (p > -1) {
|
||||||
|
|||||||
@ -465,7 +465,7 @@ export class ComfyUI {
|
|||||||
const fileInput = $el("input", {
|
const fileInput = $el("input", {
|
||||||
id: "comfy-file-input",
|
id: "comfy-file-input",
|
||||||
type: "file",
|
type: "file",
|
||||||
accept: ".json,image/png",
|
accept: ".json,image/png,.latent",
|
||||||
style: { display: "none" },
|
style: { display: "none" },
|
||||||
parent: document.body,
|
parent: document.body,
|
||||||
onchange: () => {
|
onchange: () => {
|
||||||
|
|||||||
@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
|||||||
|
|
||||||
var v = valueControl.value;
|
var v = valueControl.value;
|
||||||
|
|
||||||
let min = targetWidget.options.min;
|
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||||
let max = targetWidget.options.max;
|
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||||
// limit to something that javascript can handle
|
let current_length = targetWidget.options.values.length;
|
||||||
max = Math.min(1125899906842624, max);
|
|
||||||
min = Math.max(-1125899906842624, min);
|
|
||||||
let range = (max - min) / (targetWidget.options.step / 10);
|
|
||||||
|
|
||||||
//adjust values based on valueControl Behaviour
|
switch (v) {
|
||||||
switch (v) {
|
case "increment":
|
||||||
case "fixed":
|
current_index += 1;
|
||||||
break;
|
break;
|
||||||
case "increment":
|
case "decrement":
|
||||||
targetWidget.value += targetWidget.options.step / 10;
|
current_index -= 1;
|
||||||
break;
|
break;
|
||||||
case "decrement":
|
case "randomize":
|
||||||
targetWidget.value -= targetWidget.options.step / 10;
|
current_index = Math.floor(Math.random() * current_length);
|
||||||
break;
|
default:
|
||||||
case "randomize":
|
break;
|
||||||
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
}
|
||||||
default:
|
current_index = Math.max(0, current_index);
|
||||||
break;
|
current_index = Math.min(current_length - 1, current_index);
|
||||||
|
if (current_index >= 0) {
|
||||||
|
let value = targetWidget.options.values[current_index];
|
||||||
|
targetWidget.value = value;
|
||||||
|
targetWidget.callback(value);
|
||||||
|
}
|
||||||
|
} else { //number
|
||||||
|
let min = targetWidget.options.min;
|
||||||
|
let max = targetWidget.options.max;
|
||||||
|
// limit to something that javascript can handle
|
||||||
|
max = Math.min(1125899906842624, max);
|
||||||
|
min = Math.max(-1125899906842624, min);
|
||||||
|
let range = (max - min) / (targetWidget.options.step / 10);
|
||||||
|
|
||||||
|
//adjust values based on valueControl Behaviour
|
||||||
|
switch (v) {
|
||||||
|
case "fixed":
|
||||||
|
break;
|
||||||
|
case "increment":
|
||||||
|
targetWidget.value += targetWidget.options.step / 10;
|
||||||
|
break;
|
||||||
|
case "decrement":
|
||||||
|
targetWidget.value -= targetWidget.options.step / 10;
|
||||||
|
break;
|
||||||
|
case "randomize":
|
||||||
|
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
/*check if values are over or under their respective
|
||||||
|
* ranges and set them to min or max.*/
|
||||||
|
if (targetWidget.value < min)
|
||||||
|
targetWidget.value = min;
|
||||||
|
|
||||||
|
if (targetWidget.value > max)
|
||||||
|
targetWidget.value = max;
|
||||||
}
|
}
|
||||||
/*check if values are over or under their respective
|
|
||||||
* ranges and set them to min or max.*/
|
|
||||||
if (targetWidget.value < min)
|
|
||||||
targetWidget.value = min;
|
|
||||||
|
|
||||||
if (targetWidget.value > max)
|
|
||||||
targetWidget.value = max;
|
|
||||||
}
|
}
|
||||||
return valueControl;
|
return valueControl;
|
||||||
};
|
};
|
||||||
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
|
|||||||
computeSize(node.size);
|
computeSize(node.size);
|
||||||
}
|
}
|
||||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||||
const t = ctx.getTransform();
|
|
||||||
const margin = 10;
|
const margin = 10;
|
||||||
|
const elRect = ctx.canvas.getBoundingClientRect();
|
||||||
|
const transform = new DOMMatrix()
|
||||||
|
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||||
|
.multiplySelf(ctx.getTransform())
|
||||||
|
.translateSelf(margin, margin + y);
|
||||||
|
|
||||||
Object.assign(this.inputEl.style, {
|
Object.assign(this.inputEl.style, {
|
||||||
left: `${t.a * margin + t.e}px`,
|
transformOrigin: "0 0",
|
||||||
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
transform: transform,
|
||||||
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
left: "0px",
|
||||||
background: (!node.color)?'':node.color,
|
top: "0px",
|
||||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
width: `${widgetWidth - (margin * 2)}px`,
|
||||||
|
height: `${this.parent.inputHeight - (margin * 2)}px`,
|
||||||
position: "absolute",
|
position: "absolute",
|
||||||
|
background: (!node.color)?'':node.color,
|
||||||
color: (!node.color)?'':'white',
|
color: (!node.color)?'':'white',
|
||||||
zIndex: app.graph._nodes.indexOf(node),
|
zIndex: app.graph._nodes.indexOf(node),
|
||||||
fontSize: `${t.d * 10.0}px`,
|
|
||||||
});
|
});
|
||||||
this.inputEl.hidden = !visible;
|
this.inputEl.hidden = !visible;
|
||||||
},
|
},
|
||||||
|
|||||||
@ -39,6 +39,8 @@ body {
|
|||||||
padding: 2px;
|
padding: 2px;
|
||||||
resize: none;
|
resize: none;
|
||||||
border: none;
|
border: none;
|
||||||
|
box-sizing: border-box;
|
||||||
|
font-size: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.comfy-modal {
|
.comfy-modal {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user