Merge remote-tracking branch 'upstream/master' into addBatchIndex

flyingshutter 2023-05-27 16:06:57 +02:00
commit abc3d0baf2
26 changed files with 1142 additions and 421 deletions

View File

@@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
        x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
+    """DPM-Solver++(2M) SDE."""
+
+    if solver_type not in {'heun', 'midpoint'}:
+        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
+
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+
+    old_denoised = None
+    h_last = None
+    h = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            # DPM-Solver++(2M) SDE
+            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+            h = s - t
+            eta_h = eta * h
+
+            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+
+            if old_denoised is not None:
+                r = h_last / h
+                if solver_type == 'heun':
+                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
+                elif solver_type == 'midpoint':
+                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+
+        old_denoised = denoised
+        h_last = h
+    return x
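A minimal sketch of exercising the new sampler outside ComfyUI, assuming the standalone k-diffusion package (which comfy/k_diffusion tracks) is installed; the lambda is a toy stand-in for a real denoiser, not anything from this repo:

import torch
from k_diffusion import sampling

denoiser = lambda x, sigma, **kwargs: torch.zeros_like(x)  # toy model: always predicts an all-zero image
sigmas = sampling.get_sigmas_karras(n=12, sigma_min=0.03, sigma_max=14.6)
x = torch.randn(1, 4, 8, 8) * sigmas[0]                    # start from pure noise at sigma_max
out = sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, solver_type='midpoint')
print(out.shape)  # torch.Size([1, 4, 8, 8])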

View File

@@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
        return x+h
+
+def slice_attention(q, k, v):
+    r1 = torch.zeros_like(k, device=q.device)
+    scale = (int(q.shape[-1])**(-0.5))
+
+    mem_free_total = model_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = torch.bmm(q[:, i:end], k) * scale
+
+                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
+                del s1
+
+                r1[:, :, i:end] = torch.bmm(v, s2)
+                del s2
+            break
+        except model_management.OOM_EXCEPTION as e:
+            steps *= 2
+            if steps > 128:
+                raise e
+            print("out of memory error, increasing steps and trying again", steps)
+
+    return r1
+
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
@@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
        # compute attention
        b,c,h,w = q.shape
-        scale = (int(c)**(-0.5))
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        v = v.reshape(b,c,h*w)
-        r1 = torch.zeros_like(k, device=q.device)
-
-        mem_free_total = model_management.get_free_memory(q.device)
-
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
-        modifier = 3 if q.element_size() == 2 else 2.5
-        mem_required = tensor_size * modifier
-        steps = 1
-
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
-        while True:
-            try:
-                slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-                for i in range(0, q.shape[1], slice_size):
-                    end = i + slice_size
-                    s1 = torch.bmm(q[:, i:end], k) * scale
-
-                    s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
-                    del s1
-
-                    r1[:, :, i:end] = torch.bmm(v, s2)
-                    del s2
-                break
-            except model_management.OOM_EXCEPTION as e:
-                steps *= 2
-                if steps > 128:
-                    raise e
-                print("out of memory error, increasing steps and trying again", steps)
+        r1 = slice_attention(q, k, v)
        h_ = r1.reshape(b,c,h,w)
        del r1
        h_ = self.proj_out(h_)

        return x+h_
@@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
        # compute attention
        B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
-
        q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
            (q, k, v),
        )
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        try:
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+            out = out.transpose(2, 3).reshape(B, C, H, W)
+        except model_management.OOM_EXCEPTION as e:
+            print("scaled_dot_product_attention OOMed: switched to slice attention")
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+
        out = self.proj_out(out)

        return x+out
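The try/except above is the general pattern: attempt the fused PyTorch kernel first, then fall back to computing attention in query chunks. A self-contained sketch of that fallback, with illustrative shapes and a plain RuntimeError standing in for ComfyUI's model_management.OOM_EXCEPTION:

import torch

q = k = v = torch.randn(1, 1, 4096, 64)
try:
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
except RuntimeError:  # e.g. CUDA out of memory
    scale = q.shape[-1] ** -0.5
    chunks = []
    for q_chunk in q.split(1024, dim=-2):  # bound the size of each score matrix
        attn = torch.softmax(q_chunk @ k.transpose(-2, -1) * scale, dim=-1)
        chunks.append(attn @ v)
    out = torch.cat(chunks, dim=-2)
print(out.shape)  # torch.Size([1, 1, 4096, 64])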

View File

@@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
    """
    B, N, _ = metric.shape

-    if r <= 0:
+    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

View File

@@ -127,6 +127,32 @@ if args.cpu:

print(f"Set vram state to: {vram_state.name}")

+def get_torch_device():
+    global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if vram_state == VRAMState.MPS:
+        return torch.device("mps")
+    if vram_state == VRAMState.CPU:
+        return torch.device("cpu")
+    else:
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.cuda.current_device()
+
+def get_torch_device_name(device):
+    if hasattr(device, 'type'):
+        return "{}".format(device.type)
+    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
+
+try:
+    print("Using device:", get_torch_device_name(get_torch_device()))
+except:
+    print("Could not pick default device.")
+
current_loaded_model = None
current_gpu_controlnets = []
@@ -233,22 +259,6 @@ def unload_if_low_vram(model):
        return model.cpu()
    return model

-def get_torch_device():
-    global xpu_available
-    global directml_enabled
-    if directml_enabled:
-        global directml_device
-        return directml_device
-    if vram_state == VRAMState.MPS:
-        return torch.device("mps")
-    if vram_state == VRAMState.CPU:
-        return torch.device("cpu")
-    else:
-        if xpu_available:
-            return torch.device("xpu")
-        else:
-            return torch.cuda.current_device()
-
def get_autocast_device(dev):
    if hasattr(dev, 'type'):
        return dev.type

View File

@@ -2,17 +2,26 @@ import torch
import comfy.model_management
import comfy.samplers
import math
+import numpy as np

-def prepare_noise(latent_image, seed, skip=0):
+def prepare_noise(latent_image, seed, noise_inds=None):
    """
    creates random noise given a latent image and a seed.
    optional arg skip can be used to skip and discard x number of noise generations for a given seed
    """
    generator = torch.manual_seed(seed)
-    for _ in range(skip):
-        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-    return noise
+    if noise_inds is None:
+        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+
+    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
+    noises = []
+    for i in range(unique_inds[-1]+1):
+        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+        if i in unique_inds:
+            noises.append(noise)
+    noises = [noises[i] for i in inverse]
+    noises = torch.cat(noises, axis=0)
+    return noises

def prepare_mask(noise_mask, shape, device):
    """ensures noise mask is of proper dimensions"""

View File

@@ -6,6 +6,10 @@ import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
+import math
+
+def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
+    return abs(a*b) // math.gcd(a, b)

#The main sampling function shared by all the samplers
#Returns predicted noise
@@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
        if c1.keys() != c2.keys():
            return False
        if 'c_crossattn' in c1:
-            if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
-                return False
+            s1 = c1['c_crossattn'].shape
+            s2 = c2['c_crossattn'].shape
+            if s1 != s2:
+                if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                    return False
+
+                mult_min = lcm(s1[1], s2[1])
+                diff = mult_min // min(s1[1], s2[1])
+                if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                    return False
        if 'c_concat' in c1:
            if c1['c_concat'].shape != c2['c_concat'].shape:
                return False
@@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
        c_crossattn = []
        c_concat = []
        c_adm = []
+        crossattn_max_len = 0
        for x in c_list:
            if 'c_crossattn' in x:
-                c_crossattn.append(x['c_crossattn'])
+                c = x['c_crossattn']
+                if crossattn_max_len == 0:
+                    crossattn_max_len = c.shape[1]
+                else:
+                    crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+                c_crossattn.append(c)
            if 'c_concat' in x:
                c_concat.append(x['c_concat'])
            if 'c_adm' in x:
                c_adm.append(x['c_adm'])
        out = {}
-        if len(c_crossattn) > 0:
-            out['c_crossattn'] = [torch.cat(c_crossattn)]
+        c_crossattn_out = []
+        for c in c_crossattn:
+            if c.shape[1] < crossattn_max_len:
+                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+            c_crossattn_out.append(c)
+
+        if len(c_crossattn_out) > 0:
+            out['c_crossattn'] = [torch.cat(c_crossattn_out)]
        if len(c_concat) > 0:
            out['c_concat'] = [torch.cat(c_concat)]
        if len(c_adm) > 0:
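The lcm-based padding above works because, as the source comment says, repeating a conditioning tensor along the token axis leaves the cross-attention result unchanged. A shape-only sketch of the batching it enables (77 and 154 are the usual CLIP token counts):

import math
import torch

def lcm(a, b):
    return abs(a * b) // math.gcd(a, b)

c1 = torch.randn(1, 77, 768)   # e.g. one CLIP encode
c2 = torch.randn(1, 154, 768)  # e.g. two concatenated encodes
max_len = lcm(c1.shape[1], c2.shape[1])  # 154
padded = [c.repeat(1, max_len // c.shape[1], 1) for c in (c1, c2)]
batch = torch.cat(padded)
print(batch.shape)  # torch.Size([2, 154, 768])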
@@ -471,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):

class KSampler:
-    SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
+    SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
    SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
-                "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
+                "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]

    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
        self.model = model
@@ -508,6 +532,8 @@ class KSampler:
        if self.scheduler == "karras":
            sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+        elif self.scheduler == "exponential":
+            sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
        elif self.scheduler == "normal":
            sigmas = self.model_wrap.get_sigmas(steps)
        elif self.scheduler == "simple":

View File

@@ -581,12 +581,9 @@ class VAE:
        samples = samples.cpu()
        return samples

-def resize_image_to(tensor, target_latent_tensor, batched_number):
-    tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
-    target_batch_size = target_latent_tensor.shape[0]
+def broadcast_image_to(tensor, target_batch_size, batched_number):
    current_batch_size = tensor.shape[0]
-    print(current_batch_size, target_batch_size)
+    #print(current_batch_size, target_batch_size)
    if current_batch_size == 1:
        return tensor
@@ -623,7 +620,9 @@ class ControlNet:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
-            self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
+        if x_noisy.shape[0] != self.cond_hint.shape[0]:
+            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        if self.control_model.dtype == torch.float16:
            precision_scope = torch.autocast
@@ -794,10 +793,14 @@ class T2IAdapter:
        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
+            self.control_input = None
            self.cond_hint = None
-            self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
+        if x_noisy.shape[0] != self.cond_hint.shape[0]:
+            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
+        if self.control_input is None:
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint)
            self.t2i_model.cpu()

View File

@@ -72,7 +72,7 @@ class MaskToImage:
    FUNCTION = "mask_to_image"

    def mask_to_image(self, mask):
-        result = mask[None, :, :, None].expand(-1, -1, -1, 3)
+        result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        return (result,)

class ImageToMask:

View File

@@ -59,6 +59,12 @@ class Blend:
    def g(self, x):
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

+def gaussian_kernel(kernel_size: int, sigma: float):
+    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
+    d = torch.sqrt(x * x + y * y)
+    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+    return g / g.sum()
+
class Blur:
    def __init__(self):
        pass
@@ -88,12 +94,6 @@ class Blur:

    CATEGORY = "image/postprocessing"

-    def gaussian_kernel(self, kernel_size: int, sigma: float):
-        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
-        d = torch.sqrt(x * x + y * y)
-        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
-        return g / g.sum()
-
    def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
        if blur_radius == 0:
            return (image,)
@@ -101,10 +101,11 @@ class Blur:
        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
-        kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
+        kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
-        blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
+        padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
+        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
        blurred = blurred.permute(0, 2, 3, 1)

        return (blurred,)
@@ -167,9 +168,15 @@ class Sharpen:
                    "max": 31,
                    "step": 1
                }),
-                "alpha": ("FLOAT", {
+                "sigma": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
+                    "max": 10.0,
+                    "step": 0.1
+                }),
+                "alpha": ("FLOAT", {
+                    "default": 1.0,
+                    "min": 0.0,
                    "max": 5.0,
                    "step": 0.1
                }),
@@ -181,21 +188,21 @@ class Sharpen:

    CATEGORY = "image/postprocessing"

-    def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
+    def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
        if sharpen_radius == 0:
            return (image,)

        batch_size, height, width, channels = image.shape

        kernel_size = sharpen_radius * 2 + 1
-        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
+        kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
        center = kernel_size // 2
-        kernel[center, center] = kernel_size**2
-        kernel *= alpha
+        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
-        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
+        tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
+        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
        sharpened = sharpened.permute(0, 2, 3, 1)

        result = torch.clamp(sharpened, 0, 1)
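A quick check of the kernel construction above: after the center tap is adjusted, the kernel sums to 1, so flat regions keep their brightness while edges get boosted by the negative Gaussian surround:

import torch

def gaussian_kernel(kernel_size, sigma):
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()

alpha, kernel_size = 1.0, 5
kernel = gaussian_kernel(kernel_size, 1.0) * -(alpha * 10)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
print(round(float(kernel.sum()), 6))  # 1.0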

View File

@@ -0,0 +1,108 @@
+import torch
+
+class LatentRebatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "latents": ("LATENT",),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              }}
+    RETURN_TYPES = ("LATENT",)
+    INPUT_IS_LIST = True
+    OUTPUT_IS_LIST = (True, )
+
+    FUNCTION = "rebatch"
+
+    CATEGORY = "latent/batch"
+
+    @staticmethod
+    def get_batch(latents, list_ind, offset):
+        '''prepare a batch out of the list of latents'''
+        samples = latents[list_ind]['samples']
+        shape = samples.shape
+        mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
+        if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8:
+            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
+        if mask.shape[0] < samples.shape[0]:
+            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
+        if 'batch_index' in latents[list_ind]:
+            batch_inds = latents[list_ind]['batch_index']
+        else:
+            batch_inds = [x+offset for x in range(shape[0])]
+        return samples, mask, batch_inds
+
+    @staticmethod
+    def get_slices(indexable, num, batch_size):
+        '''divides an indexable object into num slices of length batch_size, and a remainder'''
+        slices = []
+        for i in range(num):
+            slices.append(indexable[i*batch_size:(i+1)*batch_size])
+        if num * batch_size < len(indexable):
+            return slices, indexable[num * batch_size:]
+        else:
+            return slices, None
+
+    @staticmethod
+    def slice_batch(batch, num, batch_size):
+        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
+        return list(zip(*result))
+
+    @staticmethod
+    def cat_batch(batch1, batch2):
+        if batch1[0] is None:
+            return batch2
+        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
+        return result
+
+    def rebatch(self, latents, batch_size):
+        batch_size = batch_size[0]
+
+        output_list = []
+        current_batch = (None, None, None)
+        processed = 0
+
+        for i in range(len(latents)):
+            # fetch new entry of list
+            #samples, masks, indices = self.get_batch(latents, i)
+            next_batch = self.get_batch(latents, i, processed)
+            processed += len(next_batch[2])
+            # set to current if current is None
+            if current_batch[0] is None:
+                current_batch = next_batch
+            # add previous to list if dimensions do not match
+            elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
+                sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+                output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+                current_batch = next_batch
+            # cat if everything checks out
+            else:
+                current_batch = self.cat_batch(current_batch, next_batch)
+
+            # add to list if dimensions gone above target batch size
+            if current_batch[0].shape[0] > batch_size:
+                num = current_batch[0].shape[0] // batch_size
+                sliced, remainder = self.slice_batch(current_batch, num, batch_size)
+
+                for i in range(num):
+                    output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
+
+                current_batch = remainder
+
+        #add remainder
+        if current_batch[0] is not None:
+            sliced, _ = self.slice_batch(current_batch, 1, batch_size)
+            output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+
+        #get rid of empty masks
+        for s in output_list:
+            if s['noise_mask'].mean() == 1.0:
+                del s['noise_mask']
+
+        return (output_list,)
+
+NODE_CLASS_MAPPINGS = {
+    "RebatchLatents": LatentRebatch,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "RebatchLatents": "Rebatch Latents",
+}

View File

@@ -17,7 +17,7 @@ class UpscaleModelLoader:

    def load_model(self, model_name):
        model_path = folder_paths.get_full_path("upscale_models", model_name)
-        sd = comfy.utils.load_torch_file(model_path)
+        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        out = model_loading.load_state_dict(sd).eval()
        return (out, )
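safe_load=True makes the loader refuse pickled payloads in non-safetensors checkpoints. A rough standalone analogue in plain torch (illustrative path; weights_only needs torch >= 1.13):

import torch

sd = torch.load("upscaler.pth", map_location="cpu", weights_only=True)  # raises instead of executing pickled code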

View File

@@ -6,6 +6,7 @@ import threading
import heapq
import traceback
import gc
+import time

import torch
import nodes
@@ -26,21 +27,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
            input_data_all[x] = obj
        else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
-                input_data_all[x] = input_data
+                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
-                input_data_all[x] = prompt
+                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                if "extra_pnginfo" in extra_data:
-                    input_data_all[x] = extra_data['extra_pnginfo']
+                    input_data_all[x] = [extra_data['extra_pnginfo']]
            if h[x] == "UNIQUE_ID":
-                input_data_all[x] = unique_id
+                input_data_all[x] = [unique_id]
    return input_data_all
-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+    # check if node wants the lists
+    input_is_list = False
+    if hasattr(obj, "INPUT_IS_LIST"):
+        input_is_list = obj.INPUT_IS_LIST
+
+    max_len_input = max([len(x) for x in input_data_all.values()])
+
+    # get a slice of inputs, repeat last input when list isn't long enough
+    def slice_dict(d, i):
+        d_new = dict()
+        for k,v in d.items():
+            d_new[k] = v[i if len(v) > i else -1]
+        return d_new
+
+    results = []
+    if input_is_list:
+        if allow_interrupt:
+            nodes.before_node_execution()
+        results.append(getattr(obj, func)(**input_data_all))
+    else:
+        for i in range(max_len_input):
+            if allow_interrupt:
+                nodes.before_node_execution()
+            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
+    return results
+
+def get_output_data(obj, input_data_all):
+    results = []
+    uis = []
+    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
+
+    for r in return_values:
+        if isinstance(r, dict):
+            if 'ui' in r:
+                uis.append(r['ui'])
+            if 'result' in r:
+                results.append(r['result'])
+        else:
+            results.append(r)
+
+    output = []
+    if len(results) > 0:
+        # check which outputs need concatenating
+        output_is_list = [False] * len(results[0])
+        if hasattr(obj, "OUTPUT_IS_LIST"):
+            output_is_list = obj.OUTPUT_IS_LIST
+
+        # merge node execution results
+        for i, is_list in zip(range(len(results[0])), output_is_list):
+            if is_list:
+                output.append([x for o in results for x in o[i]])
+            else:
+                output.append([o[i] for o in results])
+
+    ui = dict()
+    if len(uis) > 0:
+        ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
+    return output, ui
+
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
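The list-mapping rule above is easiest to see with toy data: every input now arrives as a list, the node function is called once per index, and shorter lists repeat their last element. A self-contained sketch with a hypothetical AddNode (not a real ComfyUI node):

class AddNode:
    FUNCTION = "add"
    def add(self, a, b):
        return (a + b,)

inputs = {"a": [1, 2, 3], "b": [10]}   # "b" repeats its last value
node = AddNode()
max_len = max(len(v) for v in inputs.values())
results = []
for i in range(max_len):
    sliced = {k: v[i if len(v) > i else -1] for k, v in inputs.items()}
    results.append(getattr(node, node.FUNCTION)(**sliced))
print(results)  # [(11,), (12,), (13,)]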
@@ -55,21 +117,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
-                recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
+                recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)

    input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
    if server.client_id is not None:
        server.last_node_id = unique_id
-        server.send_sync("executing", { "node": unique_id }, server.client_id)
+        server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
    obj = class_def()

-    nodes.before_node_execution()
-    outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
-    if "ui" in outputs[unique_id]:
+    output_data, output_ui = get_output_data(obj, input_data_all)
+    outputs[unique_id] = output_data
+    if len(output_ui) > 0:
+        outputs_ui[unique_id] = output_ui
        if server.client_id is not None:
-            server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
-        if "result" in outputs[unique_id]:
-            outputs[unique_id] = outputs[unique_id]["result"]
+            server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    executed.add(unique_id)

def recursive_will_execute(prompt, outputs, current_item):
@@ -105,7 +166,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            if input_data_all is not None:
                try:
-                    is_changed = class_def.IS_CHANGED(**input_data_all)
+                    #is_changed = class_def.IS_CHANGED(**input_data_all)
+                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                    prompt[unique_id]['is_changed'] = is_changed
                except:
                    to_delete = True
@@ -144,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
    def __init__(self, server):
        self.outputs = {}
+        self.outputs_ui = {}
        self.old_prompt = {}
        self.server = server

-    def execute(self, prompt, extra_data={}):
+    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)

        if "client_id" in extra_data:
@@ -155,6 +218,10 @@ class PromptExecutor:
        else:
            self.server.client_id = None

+        execution_start_time = time.perf_counter()
+        if self.server.client_id is not None:
+            self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
+
        with torch.inference_mode():
            #delete cached outputs if nodes don't exist for them
            to_delete = []
@@ -169,32 +236,34 @@ class PromptExecutor:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
+            for x in list(self.outputs_ui.keys()):
+                if x not in current_outputs:
+                    d = self.outputs_ui.pop(x)
+                    del d
+
+            if self.server.client_id is not None:
+                self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
            executed = set()
            try:
                to_execute = []
-                for x in prompt:
-                    class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
-                    if hasattr(class_, 'OUTPUT_NODE'):
-                        to_execute += [(0, x)]
+                for x in list(execute_outputs):
+                    to_execute += [(0, x)]

                while len(to_execute) > 0:
                    #always execute the output that depends on the least amount of unexecuted nodes first
                    to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                    x = to_execute.pop(0)[-1]

-                    class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
-                    if hasattr(class_, 'OUTPUT_NODE'):
-                        if class_.OUTPUT_NODE == True:
-                            valid = False
-                            try:
-                                m = validate_inputs(prompt, x)
-                                valid = m[0]
-                            except:
-                                valid = False
-                            if valid:
-                                recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
+                    recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
            except Exception as e:
-                print(traceback.format_exc())
+                if isinstance(e, comfy.model_management.InterruptProcessingException):
+                    print("Processing interrupted")
+                else:
+                    message = str(traceback.format_exc())
+                    print(message)
+                    if self.server.client_id is not None:
+                        self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
+
                to_delete = []
                for o in self.outputs:
                    if (o not in current_outputs) and (o not in executed):
@@ -210,14 +279,18 @@ class PromptExecutor:
                    self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if self.server.client_id is not None:
-                self.server.send_sync("executing", { "node": None }, self.server.client_id)
+                self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)

+        print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
        gc.collect()
        comfy.model_management.soft_empty_cache()

-def validate_inputs(prompt, item):
+def validate_inputs(prompt, item, validated):
    unique_id = item
+    if unique_id in validated:
+        return validated[unique_id]
+
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -238,8 +311,9 @@ def validate_inputs(prompt, item):
                r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
                if r[val[1]] != type_input:
                    return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
-                r = validate_inputs(prompt, o_id)
+                r = validate_inputs(prompt, o_id, validated)
                if r[0] == False:
+                    validated[o_id] = r
                    return r
            else:
                if type_input == "INT":
@@ -254,20 +328,25 @@ def validate_inputs(prompt, item):
                if len(info) > 1:
                    if "min" in info[1] and val < info[1]["min"]:
-                        return (False, "Value smaller than min. {}, {}".format(class_type, x))
+                        return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
                    if "max" in info[1] and val > info[1]["max"]:
-                        return (False, "Value bigger than max. {}, {}".format(class_type, x))
+                        return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))

                if hasattr(obj_class, "VALIDATE_INPUTS"):
                    input_data_all = get_input_data(inputs, obj_class, unique_id)
-                    ret = obj_class.VALIDATE_INPUTS(**input_data_all)
-                    if ret != True:
-                        return (False, "{}, {}".format(class_type, ret))
+                    #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
+                    ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
+                    for r in ret:
+                        if r != True:
+                            return (False, "{}, {}".format(class_type, r))
                else:
                    if isinstance(type_input, list):
                        if val not in type_input:
                            return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
-    return (True, "")
+
+    ret = (True, "")
+    validated[unique_id] = ret
+    return ret

def validate_prompt(prompt):
    outputs = set()
@@ -281,11 +360,12 @@ def validate_prompt(prompt):
    good_outputs = set()
    errors = []
+    validated = {}
    for o in outputs:
        valid = False
        reason = ""
        try:
-            m = validate_inputs(prompt, o)
+            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reason = m[1]
        except Exception as e:
@@ -294,7 +374,7 @@ def validate_prompt(prompt):
            reason = "Parsing error"

        if valid == True:
-            good_outputs.add(x)
+            good_outputs.add(o)
        else:
            print("Failed to validate prompt for output {} {}".format(o, reason))
            print("output will be ignored")
@@ -304,7 +384,7 @@ def validate_prompt(prompt):
        errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
        return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))

-    return (True, "")
+    return (True, "", list(good_outputs))

class PromptQueue:
@@ -340,8 +420,7 @@ class PromptQueue:
            prompt = self.currently_running.pop(item_id)
            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
            for o in outputs:
-                if "ui" in outputs[o]:
-                    self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
+                self.history[prompt[1]]["outputs"][o] = outputs[o]
            self.server.queue_updated()

    def get_current_queue(self):

View File

@@ -147,4 +147,37 @@ def get_filename_list(folder_name):
        output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
    return sorted(list(output_list))

+def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
+    def map_filename(filename):
+        prefix_len = len(os.path.basename(filename_prefix))
+        prefix = filename[:prefix_len + 1]
+        try:
+            digits = int(filename[prefix_len + 1:].split('_')[0])
+        except:
+            digits = 0
+        return (digits, prefix)
+
+    def compute_vars(input, image_width, image_height):
+        input = input.replace("%width%", str(image_width))
+        input = input.replace("%height%", str(image_height))
+        return input
+
+    filename_prefix = compute_vars(filename_prefix, image_width, image_height)
+
+    subfolder = os.path.dirname(os.path.normpath(filename_prefix))
+    filename = os.path.basename(os.path.normpath(filename_prefix))
+
+    full_output_folder = os.path.join(output_dir, subfolder)
+
+    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
+        print("Saving image outside the output folder is not allowed.")
+        return {}
+
+    try:
+        counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
+    except ValueError:
+        counter = 1
+    except FileNotFoundError:
+        os.makedirs(full_output_folder, exist_ok=True)
+        counter = 1
+    return full_output_folder, filename, counter, subfolder, filename_prefix
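A hypothetical call, to make the return values concrete (assumes it runs from a ComfyUI checkout where folder_paths is importable; the %width%/%height% tokens come from compute_vars above):

import tempfile
import folder_paths

out_dir = tempfile.mkdtemp()
folder, filename, counter, subfolder, prefix = folder_paths.get_save_image_path(
    "renders/ComfyUI_%width%x%height%", out_dir, 512, 512)
print(subfolder, filename, counter)  # renders ComfyUI_512x512 1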

View File

@@ -33,8 +33,8 @@ def prompt_worker(q, server):
    e = execution.PromptExecutor(server)
    while True:
        item, item_id = q.get()
-        e.execute(item[-2], item[-1])
-        q.task_done(item_id, e.outputs)
+        e.execute(item[2], item[1], item[3], item[4])
+        q.task_done(item_id, e.outputs_ui)

async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

nodes.py
View File

@@ -6,10 +6,12 @@ import json
import hashlib
import traceback
import math
+import time

-from PIL import Image
+from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
import numpy as np
+import safetensors.torch

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@@ -28,6 +30,7 @@ import importlib

import folder_paths

def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()
@@ -145,9 +148,6 @@ class ConditioningSetMask:
        return (c, )

class VAEDecode:
-    def __init__(self, device="cpu"):
-        self.device = device
-
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@@ -160,9 +160,6 @@ class VAEDecode:
        return (vae.decode(samples["samples"]), )

class VAEDecodeTiled:
-    def __init__(self, device="cpu"):
-        self.device = device
-
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@@ -175,9 +172,6 @@ class VAEDecodeTiled:
        return (vae.decode_tiled(samples["samples"]), )

class VAEEncode:
-    def __init__(self, device="cpu"):
-        self.device = device
-
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@@ -202,9 +196,6 @@ class VAEEncode:
        return ({"samples":t}, )

class VAEEncodeTiled:
-    def __init__(self, device="cpu"):
-        self.device = device
-
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@@ -219,9 +210,6 @@ class VAEEncodeTiled:
        return ({"samples":t}, )

class VAEEncodeForInpaint:
-    def __init__(self, device="cpu"):
-        self.device = device
-
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
@@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

+class SaveLatent:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ),
+                              "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing"
+
+    def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+
+        # support save metadata for latent sharing
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        file = f"{filename}_{counter:05}_.latent"
+        file = os.path.join(full_output_folder, file)
+
+        output = {}
+        output["latent_tensor"] = samples["samples"]
+
+        safetensors.torch.save_file(output, file, metadata=metadata)
+        return {}
+
+class LoadLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
+        return {"required": {"latent": [sorted(files), ]}, }
+
+    CATEGORY = "_for_testing"
+
+    RETURN_TYPES = ("LATENT", )
+    FUNCTION = "load"
+
+    def load(self, latent):
+        latent_path = folder_paths.get_annotated_filepath(latent)
+        latent = safetensors.torch.load_file(latent_path, device="cpu")
+        samples = {"samples": latent["latent_tensor"].float()}
+        return (samples, )
+
+    @classmethod
+    def IS_CHANGED(s, latent):
+        image_path = folder_paths.get_annotated_filepath(latent)
+        m = hashlib.sha256()
+        with open(image_path, 'rb') as f:
+            m.update(f.read())
+        return m.digest().hex()
+
+    @classmethod
+    def VALIDATE_INPUTS(s, latent):
+        if not folder_paths.exists_annotated_filepath(latent):
+            return "Invalid latent file: {}".format(latent)
+        return True
+
class CheckpointLoader:
    @classmethod
    def INPUT_TYPES(s):
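The .latent files written by SaveLatent above are ordinary safetensors archives, so a round-trip works without ComfyUI (path is illustrative):

import os, tempfile
import safetensors.torch
import torch

path = os.path.join(tempfile.gettempdir(), "demo.latent")
samples = {"samples": torch.randn(1, 4, 64, 64)}
safetensors.torch.save_file({"latent_tensor": samples["samples"]}, path, metadata={"prompt": ""})
loaded = safetensors.torch.load_file(path, device="cpu")
restored = {"samples": loaded["latent_tensor"].float()}
assert torch.equal(samples["samples"], restored["samples"])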
@@ -296,7 +359,10 @@ class DiffusersLoader:
        paths = []
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if os.path.exists(search_path):
-                paths += next(os.walk(search_path))[1]
+                for root, subdir, files in os.walk(search_path, followlinks=True):
+                    if "model_index.json" in files:
+                        paths.append(os.path.relpath(root, start=search_path))
+
        return {"required": {"model_path": (paths,), }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
@@ -306,9 +372,9 @@ class DiffusersLoader:
    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
        for search_path in folder_paths.get_folder_paths("diffusers"):
            if os.path.exists(search_path):
-                paths = next(os.walk(search_path))[1]
-                if model_path in paths:
-                    model_path = os.path.join(search_path, model_path)
+                path = os.path.join(search_path, model_path)
+                if os.path.exists(path):
+                    model_path = path
                    break

        return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
@@ -629,18 +695,57 @@ class LatentFromBatch:
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT",),
                              "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
+                              "length": ("INT", {"default": 1, "min": 1, "max": 64}),
                              }}
    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "rotate"
+    FUNCTION = "frombatch"

-    CATEGORY = "latent"
+    CATEGORY = "latent/batch"

-    def rotate(self, samples, batch_index):
+    def frombatch(self, samples, batch_index, length):
        s = samples.copy()
        s_in = samples["samples"]
        batch_index = min(s_in.shape[0] - 1, batch_index)
-        s["samples"] = s_in[batch_index:batch_index + 1].clone()
-        s["batch_index"] = batch_index
+        length = min(s_in.shape[0] - batch_index, length)
+        s["samples"] = s_in[batch_index:batch_index + length].clone()
+        if "noise_mask" in samples:
+            masks = samples["noise_mask"]
+            if masks.shape[0] == 1:
+                s["noise_mask"] = masks.clone()
+            else:
+                if masks.shape[0] < s_in.shape[0]:
+                    masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
+                s["noise_mask"] = masks[batch_index:batch_index + length].clone()
+        if "batch_index" not in s:
+            s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
+        else:
+            s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
+        return (s,)
+
+class RepeatLatentBatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT",),
+                              "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              }}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "repeat"
+
+    CATEGORY = "latent/batch"
+
+    def repeat(self, samples, amount):
+        s = samples.copy()
+        s_in = samples["samples"]
+
+        s["samples"] = s_in.repeat((amount, 1,1,1))
+        if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
+            masks = samples["noise_mask"]
+            if masks.shape[0] < s_in.shape[0]:
+                masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
+            s["noise_mask"] = masks.repeat((amount, 1,1,1))
+        if "batch_index" in s:
+            offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
+            s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
        return (s,)

class LatentUpscale:
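The batch_index arithmetic in RepeatLatentBatch, in isolation: each repetition shifts the indices by one whole block, so every copy draws different noise from the seed stream:

inds = [4, 5]
amount = 3
offset = max(inds) - min(inds) + 1  # 2
print(inds + [x + (i * offset) for i in range(1, amount) for x in inds])
# [4, 5, 6, 7, 8, 9]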
@@ -795,7 +900,7 @@ class SetLatentNoiseMask:
    def set_mask(self, samples, mask):
        s = samples.copy()
-        s["noise_mask"] = mask
+        s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (s,)

def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
@@ -805,8 +910,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
-        skip = latent["batch_index"] if "batch_index" in latent else 0
-        noise = comfy.sample.prepare_noise(latent_image, seed, skip)
+        batch_inds = latent["batch_index"] if "batch_index" in latent else None
+        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

    noise_mask = None
    if "noise_mask" in latent:
@@ -901,39 +1006,7 @@ class SaveImage:
    CATEGORY = "image"

    def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
-        def map_filename(filename):
-            prefix_len = len(os.path.basename(filename_prefix))
-            prefix = filename[:prefix_len + 1]
-            try:
-                digits = int(filename[prefix_len + 1:].split('_')[0])
-            except:
-                digits = 0
-            return (digits, prefix)
-
-        def compute_vars(input):
-            input = input.replace("%width%", str(images[0].shape[1]))
-            input = input.replace("%height%", str(images[0].shape[0]))
-            return input
-
-        filename_prefix = compute_vars(filename_prefix)
-
-        subfolder = os.path.dirname(os.path.normpath(filename_prefix))
-        filename = os.path.basename(os.path.normpath(filename_prefix))
-
-        full_output_folder = os.path.join(self.output_dir, subfolder)
-
-        if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
-            print("Saving image outside the output folder is not allowed.")
-            return {}
-
-        try:
-            counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
-        except ValueError:
-            counter = 1
-        except FileNotFoundError:
-            os.makedirs(full_output_folder, exist_ok=True)
-            counter = 1
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
        results = list()
        for image in images:
            i = 255. * image.cpu().numpy()
@@ -984,6 +1057,7 @@ class LoadImage:
    def load_image(self, image):
        image_path = folder_paths.get_annotated_filepath(image)
        i = Image.open(image_path)
+        i = ImageOps.exif_transpose(i)
        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
@@ -1027,6 +1101,7 @@ class LoadImageMask:
    def load_image(self, image, channel):
        image_path = folder_paths.get_annotated_filepath(image)
        i = Image.open(image_path)
+        i = ImageOps.exif_transpose(i)
        if i.getbands() != ("R", "G", "B", "A"):
            i = i.convert("RGBA")
        mask = None
@@ -1170,6 +1245,7 @@ NODE_CLASS_MAPPINGS = {
    "EmptyLatentImage": EmptyLatentImage,
    "LatentUpscale": LatentUpscale,
    "LatentFromBatch": LatentFromBatch,
+    "RepeatLatentBatch": RepeatLatentBatch,
    "SaveImage": SaveImage,
    "PreviewImage": PreviewImage,
    "LoadImage": LoadImage,
@@ -1206,6 +1282,9 @@ NODE_CLASS_MAPPINGS = {
    "CheckpointLoader": CheckpointLoader,
    "DiffusersLoader": DiffusersLoader,
+
+    "LoadLatent": LoadLatent,
+    "SaveLatent": SaveLatent
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -1244,6 +1323,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "EmptyLatentImage": "Empty Latent Image",
    "LatentUpscale": "Upscale Latent",
    "LatentComposite": "Latent Composite",
+    "LatentFromBatch" : "Latent From Batch",
+    "RepeatLatentBatch": "Repeat Latent Batch",
    # Image
    "SaveImage": "Save Image",
    "PreviewImage": "Preview Image",
@@ -1275,14 +1356,18 @@ def load_custom_node(module_path):
            NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
+            return True
        else:
            print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
+            return False
    except Exception as e:
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
+        return False

def load_custom_nodes():
    node_paths = folder_paths.get_folder_paths("custom_nodes")
+    node_import_times = []
    for custom_node_path in node_paths:
        possible_modules = os.listdir(custom_node_path)
        if "__pycache__" in possible_modules:
@ -1291,11 +1376,25 @@ def load_custom_nodes():
for possible_module in possible_modules: for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module) module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path) if module_path.endswith(".disabled"): continue
time_before = time.perf_counter()
success = load_custom_node(module_path)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes(): def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_nodes()
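For reference, the only contract `load_custom_node` enforces is that a module exposes `NODE_CLASS_MAPPINGS` (`NODE_DISPLAY_NAME_MAPPINGS` stays optional); with the changes above it also skips `.disabled` paths and reports per-module import times and success. A minimal sketch of a module it would accept; the node class itself is hypothetical, and `FUNCTION`/`CATEGORY` follow the node-attribute convention visible in `node_info` further down in this commit:

```python
# custom_nodes/example_invert.py -- hypothetical minimal custom node module.
class ExampleInvert:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"       # name of the method to call when the node runs
    CATEGORY = "image"

    def invert(self, image):
        # IMAGE tensors are float32 in [0, 1], so this flips them
        return (1.0 - image,)

# load_custom_node() imports the module and looks for exactly this name:
NODE_CLASS_MAPPINGS = {"ExampleInvert": ExampleInvert}
# optional; node_info() falls back to the class key when it is absent
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleInvert": "Example Invert"}
```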

@@ -175,6 +175,8 @@
 "import threading\n",
 "import time\n",
 "import socket\n",
+"import urllib.request\n",
+"\n",
 "def iframe_thread(port):\n",
 "  while True:\n",
 "    time.sleep(0.5)\n",
@@ -183,7 +185,9 @@
 "    if result == 0:\n",
 "      break\n",
 "    sock.close()\n",
-"  print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
+"  print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
+"\n",
+"  print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
 "  p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
 "  for line in p.stdout:\n",
 "    print(line.decode(), end='')\n",

@@ -115,21 +115,23 @@ class PromptServer():
         def get_dir_by_type(dir_type):
             if dir_type is None:
-                type_dir = folder_paths.get_input_directory()
-            elif dir_type == "input":
+                dir_type = "input"
+
+            if dir_type == "input":
                 type_dir = folder_paths.get_input_directory()
             elif dir_type == "temp":
                 type_dir = folder_paths.get_temp_directory()
             elif dir_type == "output":
                 type_dir = folder_paths.get_output_directory()

-            return type_dir
+            return type_dir, dir_type

         def image_upload(post, image_save_function=None):
             image = post.get("image")
+            overwrite = post.get("overwrite")
             image_upload_type = post.get("type")
-            upload_dir = get_dir_by_type(image_upload_type)
+            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)

             if image and image.file:
                 filename = image.filename
@@ -148,10 +150,14 @@ class PromptServer():
                 split = os.path.splitext(filename)
                 filepath = os.path.join(full_output_folder, filename)

-                i = 1
-                while os.path.exists(filepath):
-                    filename = f"{split[0]} ({i}){split[1]}"
-                    i += 1
+                if overwrite is not None and (overwrite == "true" or overwrite == "1"):
+                    pass
+                else:
+                    i = 1
+                    while os.path.exists(filepath):
+                        filename = f"{split[0]} ({i}){split[1]}"
+                        filepath = os.path.join(full_output_folder, filename)
+                        i += 1

                 if image_save_function is not None:
                     image_save_function(image, post, filepath)
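The handler now honours an `overwrite` form field ("true" or "1") instead of always deduplicating file names, and the rename loop finally refreshes `filepath` on each pass, which the old code forgot. A hedged client sketch: `image`, `overwrite` and `type` come from this hunk, `subfolder` mirrors what the mask editor sends later in this commit, while the `/upload/image` route and the default `127.0.0.1:8188` address are assumptions:

```python
# Sketch: re-upload a file in place using the new overwrite flag.
import requests

with open("mask.png", "rb") as f:  # hypothetical local file
    r = requests.post(
        "http://127.0.0.1:8188/upload/image",  # assumed route and address
        files={"image": ("mask.png", f, "image/png")},
        data={"overwrite": "true", "type": "input", "subfolder": "clipspace"},
    )
print(r.status_code)
```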
@@ -255,22 +261,34 @@ class PromptServer():
         async def get_prompt(request):
             return web.json_response(self.get_queue_info())

+        def node_info(node_class):
+            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
+            info = {}
+            info['input'] = obj_class.INPUT_TYPES()
+            info['output'] = obj_class.RETURN_TYPES
+            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
+            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
+            info['name'] = node_class
+            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
+            info['description'] = ''
+            info['category'] = 'sd'
+            if hasattr(obj_class, 'CATEGORY'):
+                info['category'] = obj_class.CATEGORY
+            return info
+
         @routes.get("/object_info")
         async def get_object_info(request):
             out = {}
             for x in nodes.NODE_CLASS_MAPPINGS:
-                obj_class = nodes.NODE_CLASS_MAPPINGS[x]
-                info = {}
-                info['input'] = obj_class.INPUT_TYPES()
-                info['output'] = obj_class.RETURN_TYPES
-                info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
-                info['name'] = x
-                info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
-                info['description'] = ''
-                info['category'] = 'sd'
-                if hasattr(obj_class, 'CATEGORY'):
-                    info['category'] = obj_class.CATEGORY
-                out[x] = info
+                out[x] = node_info(x)
+            return web.json_response(out)
+
+        @routes.get("/object_info/{node_class}")
+        async def get_object_info_node(request):
+            node_class = request.match_info.get("node_class", None)
+            out = {}
+            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
+                out[node_class] = node_info(node_class)
             return web.json_response(out)

         @routes.get("/history")
@@ -312,13 +330,15 @@ class PromptServer():
                 if "client_id" in json_data:
                     extra_data["client_id"] = json_data["client_id"]
                 if valid[0]:
-                    self.prompt_queue.put((number, id(prompt), prompt, extra_data))
+                    prompt_id = str(uuid.uuid4())
+                    outputs_to_execute = valid[2]
+                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
+                    return web.json_response({"prompt_id": prompt_id})
                 else:
-                    resp_code = 400
-                    out_string = valid[1]
                     print("invalid prompt:", valid[1])
-
-            return web.Response(body=out_string, status=resp_code)
+                    return web.json_response({"error": valid[1]}, status=400)
+            else:
+                return web.json_response({"error": "no prompt"}, status=400)
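Queue entries are now keyed by a server-generated UUID instead of `id(prompt)`, every branch answers with JSON, and the response hands that `prompt_id` back to the caller; the `/queue` delete handler below drops its `int()` cast to match the string ids. A sketch of the round trip, assuming the handler is mounted at `/prompt` and takes the workflow under a `"prompt"` key (only `"client_id"` and the response shapes appear in the hunk):

```python
# Sketch: queue a workflow and keep the UUID the server now returns.
import json
import urllib.error
import urllib.request

workflow = {}  # hypothetical placeholder for a real node graph
req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",  # assumed route and address
    data=json.dumps({"prompt": workflow, "client_id": "doc-example"}).encode(),
    headers={"Content-Type": "application/json"},
)
try:
    with urllib.request.urlopen(req) as resp:
        prompt_id = json.load(resp)["prompt_id"]
    print("queued as", prompt_id)  # same string later accepted by /queue delete
except urllib.error.HTTPError as e:
    print("rejected:", json.load(e)["error"])  # errors are JSON now, not bare text
```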
         @routes.post("/queue")
         async def post_queue(request):
@@ -329,7 +349,7 @@ class PromptServer():
             if "delete" in json_data:
                 to_delete = json_data['delete']
                 for id_to_delete in to_delete:
-                    delete_func = lambda a: a[1] == int(id_to_delete)
+                    delete_func = lambda a: a[1] == id_to_delete
                     self.prompt_queue.delete_queue_item(delete_func)

             return web.Response(status=200)
@@ -355,7 +375,7 @@ class PromptServer():
     def add_routes(self):
         self.app.add_routes(self.routes)
         self.app.add_routes([
-            web.static('/', self.web_root),
+            web.static('/', self.web_root, follow_symlinks=True),
         ])

     def get_queue_info(self):

@@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) {
 class MaskEditorDialog extends ComfyDialog {
 	static instance = null;

+	static getInstance() {
+		if(!MaskEditorDialog.instance) {
+			MaskEditorDialog.instance = new MaskEditorDialog(app);
+		}
+
+		return MaskEditorDialog.instance;
+	}
+
+	is_layout_created = false;
+
 	constructor() {
 		super();
 		this.element = $el("div.comfy-modal", { parent: document.body },
 			[ $el("div.comfy-modal-content",
 				[...this.createButtons()]),
 			]);
+		MaskEditorDialog.instance = this;
 	}

 	createButtons() {
 		return [];
 	}

-	clearMask(self) {
-	}
-
 	createButton(name, callback) {
 		var button = document.createElement("button");
 		button.innerText = name;
 		button.addEventListener("click", callback);
 		return button;
 	}

 	createLeftButton(name, callback) {
 		var button = this.createButton(name, callback);
 		button.style.cssFloat = "left";
 		button.style.marginRight = "4px";
 		return button;
 	}

 	createRightButton(name, callback) {
 		var button = this.createButton(name, callback);
 		button.style.cssFloat = "right";
 		button.style.marginLeft = "4px";
 		return button;
 	}

 	createLeftSlider(self, name, callback) {
 		const divElement = document.createElement('div');
 		divElement.id = "maskeditor-slider";
@@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog {
 		brush.style.MozBorderRadius = "50%";
 		brush.style.WebkitBorderRadius = "50%";
 		brush.style.position = "absolute";
-		brush.style.zIndex = 100;
+		brush.style.zIndex = 8889;
 		brush.style.pointerEvents = "none";
 		this.brush = brush;
 		this.element.appendChild(imgCanvas);
@@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog {
 			document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
 			self.close();
 		});
-		var saveButton = this.createRightButton("Save", () => {
+
+		this.saveButton = this.createRightButton("Save", () => {
 			document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
 			document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
 			self.save();
@@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog {
 		this.element.appendChild(bottom_panel);

 		bottom_panel.appendChild(clearButton);
-		bottom_panel.appendChild(saveButton);
+		bottom_panel.appendChild(this.saveButton);
 		bottom_panel.appendChild(cancelButton);
 		bottom_panel.appendChild(brush_size_slider);

-		this.element.style.display = "block";
 		imgCanvas.style.position = "relative";
 		imgCanvas.style.top = "200";
 		imgCanvas.style.left = "0";
@@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog {
 	}

 	show() {
-		// layout
-		const imgCanvas = document.createElement('canvas');
-		const maskCanvas = document.createElement('canvas');
-		const backupCanvas = document.createElement('canvas');
-
-		imgCanvas.id = "imageCanvas";
-		maskCanvas.id = "maskCanvas";
-		backupCanvas.id = "backupCanvas";
-
-		this.setlayout(imgCanvas, maskCanvas);
-
-		// prepare content
-		this.maskCanvas = maskCanvas;
-		this.backupCanvas = backupCanvas;
-		this.maskCtx = maskCanvas.getContext('2d');
-		this.backupCtx = backupCanvas.getContext('2d');
-
-		this.setImages(imgCanvas, backupCanvas);
-		this.setEventHandler(maskCanvas);
+		if(!this.is_layout_created) {
+			// layout
+			const imgCanvas = document.createElement('canvas');
+			const maskCanvas = document.createElement('canvas');
+			const backupCanvas = document.createElement('canvas');
+
+			imgCanvas.id = "imageCanvas";
+			maskCanvas.id = "maskCanvas";
+			backupCanvas.id = "backupCanvas";
+
+			this.setlayout(imgCanvas, maskCanvas);
+
+			// prepare content
+			this.imgCanvas = imgCanvas;
+			this.maskCanvas = maskCanvas;
+			this.backupCanvas = backupCanvas;
+			this.maskCtx = maskCanvas.getContext('2d');
+			this.backupCtx = backupCanvas.getContext('2d');
+
+			this.setEventHandler(maskCanvas);
+
+			this.is_layout_created = true;
+
+			// replacement of onClose hook since close is not real close
+			const self = this;
+			const observer = new MutationObserver(function(mutations) {
+				mutations.forEach(function(mutation) {
+					if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
+						if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
+							ComfyApp.onClipspaceEditorClosed();
+						}
+
+						self.last_display_style = self.element.style.display;
+					}
+				});
+			});
+
+			const config = { attributes: true };
+			observer.observe(this.element, config);
+		}
+
+		this.setImages(this.imgCanvas, this.backupCanvas);
+
+		if(ComfyApp.clipspace_return_node) {
+			this.saveButton.innerText = "Save to node";
+		}
+		else {
+			this.saveButton.innerText = "Save";
+		}
+		this.saveButton.disabled = false;
+
+		this.element.style.display = "block";
+		this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
+	}
+
+	isOpened() {
+		return this.element.style.display == "block";
 	}
 	setImages(imgCanvas, backupCanvas) {
@@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog {
 		const maskCtx = this.maskCtx;
 		const maskCanvas = this.maskCanvas;

+		backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
+		imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
+		maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
+
 		// image load
 		const orig_image = new Image();
 		window.addEventListener("resize", () => {
@@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog {
 		rgb_url.searchParams.set('channel', 'rgb');
 		orig_image.src = rgb_url;
 		this.image = orig_image;
-	}g
+	}

 	setEventHandler(maskCanvas) {
 		maskCanvas.addEventListener("contextmenu", (event) => {
@@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog {
 				self.brush_size = Math.min(self.brush_size+2, 100);
 			} else if (event.key === '[') {
 				self.brush_size = Math.max(self.brush_size-2, 1);
+			} else if(event.key === 'Enter') {
+				self.save();
 			}

 			self.updateBrushPreview(self);
@@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog {
 		}
 	}

-	save() {
+	async save() {
 		const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});

 		backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
@@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog {
 		formData.append('type', "input");
 		formData.append('subfolder', "clipspace");

-		uploadMask(item, formData);
+		this.saveButton.innerText = "Saving...";
+		this.saveButton.disabled = true;
+		await uploadMask(item, formData);
+		ComfyApp.onClipspaceEditorSave();
 		this.close();
 	}
 }
@@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog {
 app.registerExtension({
 	name: "Comfy.MaskEditor",
 	init(app) {
-		const callback =
+		ComfyApp.open_maskeditor =
 			function () {
-				let dlg = new MaskEditorDialog(app);
-				dlg.show();
+				const dlg = MaskEditorDialog.getInstance();
+				if(!dlg.isOpened()) {
+					dlg.show();
+				}
 			};

 		const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
-		ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback);
+		ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
 	}
 });

@@ -300,7 +300,7 @@ app.registerExtension({
 			}
 		}

-		if (widget.type === "number") {
+		if (widget.type === "number" || widget.type === "combo") {
 			addValueControlWidget(this, widget, "fixed");
 		}

@@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
         //when clicked on top of a node
         //and it is not interactive
-        if (node && this.allow_interaction && !skip_action && !this.read_only) {
+        if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
             if (!this.live_mode && !node.flags.pinned) {
                 this.bringToFront(node);
             } //if it wasn't selected?

             //not dragging mouse to connect two slots
-            if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
+            if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
                 //Search for corner for resize
                 if ( !skip_action &&
                     node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
@@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
             }

             //double clicking
-            if (is_double_click && this.selected_nodes[node.id]) {
+            if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
                 //double click node
                 if (node.onDblClick) {
                     node.onDblClick( e, pos, this );
@@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
             this.dirty_canvas = true;
         }

+        //get node over
+        var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
+
         if (this.dragging_rectangle)
         {
             this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
@@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
             this.ds.offset[1] += delta[1] / this.ds.scale;
             this.dirty_canvas = true;
             this.dirty_bgcanvas = true;
-        } else if (this.allow_interaction && !this.read_only) {
+        } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
             if (this.connecting_node) {
                 this.dirty_canvas = true;
             }

-            //get node over
-            var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
-
             //remove mouseover flag
             for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
                 if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
@@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
                 if (show_text) {
                     ctx.textAlign = "center";
                     ctx.fillStyle = text_color;
-                    ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
+                    ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
                 }
                 break;
             case "toggle":
@@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
                 ctx.fill();
                 if (show_text) {
                     ctx.fillStyle = secondary_text_color;
-                    if (w.name != null) {
-                        ctx.fillText(w.name, margin * 2, y + H * 0.7);
+                    const label = w.label || w.name;
+                    if (label != null) {
+                        ctx.fillText(label, margin * 2, y + H * 0.7);
                     }
                     ctx.fillStyle = w.value ? text_color : secondary_text_color;
                     ctx.textAlign = "right";
@@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
                     ctx.textAlign = "center";
                     ctx.fillStyle = text_color;
                     ctx.fillText(
-                        w.name + " " + Number(w.value).toFixed(3),
+                        w.label || w.name + " " + Number(w.value).toFixed(3),
                         widget_width * 0.5,
                         y + H * 0.7
                     );
@@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
                         ctx.fill();
                     }
                     ctx.fillStyle = secondary_text_color;
-                    ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
+                    ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
                     ctx.fillStyle = text_color;
                     ctx.textAlign = "right";
                     if (w.type == "number") {
@@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
                     //ctx.stroke();
                     ctx.fillStyle = secondary_text_color;
-                    if (w.name != null) {
-                        ctx.fillText(w.name, margin * 2, y + H * 0.7);
+                    const label = w.label || w.name;
+                    if (label != null) {
+                        ctx.fillText(label, margin * 2, y + H * 0.7);
                     }
                     ctx.fillStyle = text_color;
                     ctx.textAlign = "right";
@@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
         event,
         active_widget
     ) {
-        if (!node.widgets || !node.widgets.length) {
+        if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
             return null;
         }
@@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
         canvas.graph.add(group);
     };

+    /**
+     * Determines the furthest nodes in each direction
+     * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
+     * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
+     */
+    LGraphCanvas.getBoundaryNodes = function(nodes) {
+        let top = null;
+        let right = null;
+        let bottom = null;
+        let left = null;
+        for (const nID in nodes) {
+            const node = nodes[nID];
+            const [x, y] = node.pos;
+            const [width, height] = node.size;
+
+            if (top === null || y < top.pos[1]) {
+                top = node;
+            }
+            if (right === null || x + width > right.pos[0] + right.size[0]) {
+                right = node;
+            }
+            if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
+                bottom = node;
+            }
+            if (left === null || x < left.pos[0]) {
+                left = node;
+            }
+        }
+
+        return {
+            "top": top,
+            "right": right,
+            "bottom": bottom,
+            "left": left
+        };
+    }
+
+    /**
+     * Determines the furthest nodes in each direction for the currently selected nodes
+     * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
+     */
+    LGraphCanvas.prototype.boundaryNodesForSelection = function() {
+        return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
+    }
+
+    /**
+     *
+     * @param {LGraphNode[]} nodes a list of nodes
+     * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
+     * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
+     */
+    LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
+        if (!nodes) {
+            return;
+        }
+
+        const canvas = LGraphCanvas.active_canvas;
+        let boundaryNodes = []
+        if (align_to === undefined) {
+            boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
+        } else {
+            boundaryNodes = {
+                "top": align_to,
+                "right": align_to,
+                "bottom": align_to,
+                "left": align_to
+            }
+        }
+
+        for (const [_, node] of Object.entries(canvas.selected_nodes)) {
+            switch (direction) {
+                case "right":
+                    node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
+                    break;
+                case "left":
+                    node.pos[0] = boundaryNodes["left"].pos[0];
+                    break;
+                case "top":
+                    node.pos[1] = boundaryNodes["top"].pos[1];
+                    break;
+                case "bottom":
+                    node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
+                    break;
+            }
+        }
+
+        canvas.dirty_canvas = true;
+        canvas.dirty_bgcanvas = true;
+    };
+
+    LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
+        new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
+            event: event,
+            callback: inner_clicked,
+            parentMenu: prev_menu,
+        });
+
+        function inner_clicked(value) {
+            LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
+        }
+    }
+
+    LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
+        new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
+            event: event,
+            callback: inner_clicked,
+            parentMenu: prev_menu,
+        });
+
+        function inner_clicked(value) {
+            LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
+        }
+    }
+
     LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {

         var canvas = LGraphCanvas.active_canvas;
@@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
             options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
         }*/

+        if (Object.keys(this.selected_nodes).length > 1) {
+            options.push({
+                content: "Align",
+                has_submenu: true,
+                callback: LGraphCanvas.onGroupAlign,
+            })
+        }
+
         if (this._graph_stack && this._graph_stack.length > 0) {
             options.push(null, {
                 content: "Close subgraph",
@@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
                 callback: LGraphCanvas.onMenuNodeToSubgraph
             });

+        if (Object.keys(this.selected_nodes).length > 1) {
+            options.push({
+                content: "Align Selected To",
+                has_submenu: true,
+                callback: LGraphCanvas.onNodeAlign,
+            })
+        }
+
         options.push(null, {
             content: "Remove",
             disabled: !(node.removable !== false && !node.block_delete ),

@@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
 		if (res.status !== 200) {
 			throw {
-				response: await res.text(),
+				response: await res.json(),
 			};
 		}
 	}

@@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
 import { ComfyUI, $el } from "./ui.js";
 import { api } from "./api.js";
 import { defaultGraph } from "./defaultGraph.js";
-import { getPngMetadata, importA1111 } from "./pnginfo.js";
+import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";

 /**
  * @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@@ -26,6 +26,8 @@ export class ComfyApp {
 	 */
 	static clipspace = null;
 	static clipspace_invalidate_handler = null;
+	static open_maskeditor = null;
+	static clipspace_return_node = null;

 	constructor() {
 		this.ui = new ComfyUI(this);
@@ -49,6 +51,114 @@ export class ComfyApp {
 		this.shiftDown = false;
 	}

+	static isImageNode(node) {
+		return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
+	}
+
+	static onClipspaceEditorSave() {
+		if(ComfyApp.clipspace_return_node) {
+			ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
+		}
+	}
+
+	static onClipspaceEditorClosed() {
+		ComfyApp.clipspace_return_node = null;
+	}
+
+	static copyToClipspace(node) {
+		var widgets = null;
+		if(node.widgets) {
+			widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
+		}
+
+		var imgs = undefined;
+		var orig_imgs = undefined;
+		if(node.imgs != undefined) {
+			imgs = [];
+			orig_imgs = [];
+			for (let i = 0; i < node.imgs.length; i++) {
+				imgs[i] = new Image();
+				imgs[i].src = node.imgs[i].src;
+				orig_imgs[i] = imgs[i];
+			}
+		}
+
+		var selectedIndex = 0;
+		if(node.imageIndex) {
+			selectedIndex = node.imageIndex;
+		}
+
+		ComfyApp.clipspace = {
+			'widgets': widgets,
+			'imgs': imgs,
+			'original_imgs': orig_imgs,
+			'images': node.images,
+			'selectedIndex': selectedIndex,
+			'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
+		};
+
+		ComfyApp.clipspace_return_node = null;
+
+		if(ComfyApp.clipspace_invalidate_handler) {
+			ComfyApp.clipspace_invalidate_handler();
+		}
+	}
+
+	static pasteFromClipspace(node) {
+		if(ComfyApp.clipspace) {
+			// image paste
+			if(ComfyApp.clipspace.imgs && node.imgs) {
+				if(node.images && ComfyApp.clipspace.images) {
+					if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
+						app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
+					}
+					else
+						app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
+				}
+
+				if(ComfyApp.clipspace.imgs) {
+					// deep-copy to cut link with clipspace
+					if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
+						const img = new Image();
+						img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
+						node.imgs = [img];
+						node.imageIndex = 0;
+					}
+					else {
+						const imgs = [];
+						for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
+							imgs[i] = new Image();
+							imgs[i].src = ComfyApp.clipspace.imgs[i].src;
+							node.imgs = imgs;
+						}
+					}
+				}
+			}
+
+			if(node.widgets) {
+				if(ComfyApp.clipspace.images) {
+					const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
+					const index = node.widgets.findIndex(obj => obj.name === 'image');
+					if(index >= 0) {
+						node.widgets[index].value = clip_image;
+					}
+				}
+				if(ComfyApp.clipspace.widgets) {
+					ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
+						const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
+						if (prop && prop.type != 'button') {
+							prop.value = value;
+							prop.callback(value);
+						}
+					});
+				}
+			}
+
+			app.graph.setDirtyCanvas(true);
+		}
+	}
+
 	/**
 	 * Invoke an extension callback
 	 * @param {keyof ComfyExtension} method The extension callback to execute
@@ -138,102 +248,30 @@ export class ComfyApp {
 			}
 		}

-		options.push(
-			{
-				content: "Copy (Clipspace)",
-				callback: (obj) => {
-					var widgets = null;
-					if(this.widgets) {
-						widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
-					}
-
-					var imgs = undefined;
-					var orig_imgs = undefined;
-					if(this.imgs != undefined) {
-						imgs = [];
-						orig_imgs = [];
-						for (let i = 0; i < this.imgs.length; i++) {
-							imgs[i] = new Image();
-							imgs[i].src = this.imgs[i].src;
-							orig_imgs[i] = imgs[i];
-						}
-					}
-
-					ComfyApp.clipspace = {
-						'widgets': widgets,
-						'imgs': imgs,
-						'original_imgs': orig_imgs,
-						'images': this.images,
-						'selectedIndex': 0,
-						'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
-					};
-
-					if(ComfyApp.clipspace_invalidate_handler) {
-						ComfyApp.clipspace_invalidate_handler();
-					}
-				}
-			});
-
-		if(ComfyApp.clipspace != null) {
-			options.push(
-				{
-					content: "Paste (Clipspace)",
-					callback: () => {
-						if(ComfyApp.clipspace) {
-							// image paste
-							if(ComfyApp.clipspace.imgs && this.imgs) {
-								if(this.images && ComfyApp.clipspace.images) {
-									if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
-										app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
-									}
-									else
-										app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images;
-								}
-
-								if(ComfyApp.clipspace.imgs) {
-									// deep-copy to cut link with clipspace
-									if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
-										const img = new Image();
-										img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
-										this.imgs = [img];
-									}
-									else {
-										const imgs = [];
-										for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
-											imgs[i] = new Image();
-											imgs[i].src = ComfyApp.clipspace.imgs[i].src;
-											this.imgs = imgs;
-										}
-									}
-								}
-							}
-
-							if(this.widgets) {
-								if(ComfyApp.clipspace.images) {
-									const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
-									const index = this.widgets.findIndex(obj => obj.name === 'image');
-									if(index >= 0) {
-										this.widgets[index].value = clip_image;
-									}
-								}
-								if(ComfyApp.clipspace.widgets) {
-									ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
-										const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
-										if (prop && prop.type != 'button') {
-											prop.value = value;
-											prop.callback(value);
-										}
-									});
-								}
-							}
-
-							app.graph.setDirtyCanvas(true);
-						}
-					}
-				}
-			);
-		}
+		// prevent conflict of clipspace content
+		if(!ComfyApp.clipspace_return_node) {
+			options.push({
+					content: "Copy (Clipspace)",
+					callback: (obj) => { ComfyApp.copyToClipspace(this); }
+				});
+
+			if(ComfyApp.clipspace != null) {
+				options.push({
+						content: "Paste (Clipspace)",
+						callback: () => { ComfyApp.pasteFromClipspace(this); }
+					});
+			}
+
+			if(ComfyApp.isImageNode(this)) {
+				options.push({
+						content: "Open in MaskEditor",
+						callback: (obj) => {
+							ComfyApp.copyToClipspace(this);
+							ComfyApp.clipspace_return_node = this;
+							ComfyApp.open_maskeditor();
+						}
+					});
+			}
+		}
 	};
 }
@@ -864,7 +902,9 @@ export class ComfyApp {
 		await this.#loadExtensions();

 		// Create and mount the LiteGraph in the DOM
-		const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
+		const mainCanvas = document.createElement("canvas")
+		mainCanvas.style.touchAction = "none"
+		const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
 		canvasEl.tabIndex = "1";
 		document.body.prepend(canvasEl);
@@ -976,7 +1016,8 @@ export class ComfyApp {
 				for (const o in nodeData["output"]) {
 					const output = nodeData["output"][o];
 					const outputName = nodeData["output_name"][o] || output;
-					this.addOutput(outputName, output);
+					const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
+					this.addOutput(outputName, output, { shape: outputShape });
 				}

 				const s = this.computeSize();
@@ -1237,7 +1278,7 @@ export class ComfyApp {
 			try {
 				await api.queuePrompt(number, p);
 			} catch (error) {
-				this.ui.dialog.show(error.response || error.toString());
+				this.ui.dialog.show(error.response.error || error.toString());
 				break;
 			}
@@ -1283,6 +1324,11 @@ export class ComfyApp {
 				this.loadGraphData(JSON.parse(reader.result));
 			};
 			reader.readAsText(file);
+		} else if (file.name?.endsWith(".latent")) {
+			const info = await getLatentMetadata(file);
+			if (info.workflow) {
+				this.loadGraphData(JSON.parse(info.workflow));
+			}
 		}
 	}

@@ -47,6 +47,22 @@ export function getPngMetadata(file) {
 	});
 }

+export function getLatentMetadata(file) {
+	return new Promise((r) => {
+		const reader = new FileReader();
+		reader.onload = (event) => {
+			const safetensorsData = new Uint8Array(event.target.result);
+			const dataView = new DataView(safetensorsData.buffer);
+			let header_size = dataView.getUint32(0, true);
+			let offset = 8;
+			let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+
+			r(header.__metadata__);
+		};
+
+		reader.readAsArrayBuffer(file);
+	});
+}
+
 export async function importA1111(graph, parameters) {
 	const p = parameters.lastIndexOf("\nSteps:");
 	if (p > -1) {
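`getLatentMetadata` leans on the safetensors container layout: the first 8 bytes are a little-endian header length (the JS reads just the low 32 bits, enough for any realistic header), followed by that many bytes of JSON whose optional `__metadata__` object is where the workflow that `app.js` reloads above ends up. The same read in Python, with a placeholder file name:

```python
# Sketch: extract __metadata__ (and the embedded workflow) from a .latent file,
# mirroring getLatentMetadata above.
import json
import struct

with open("example.latent", "rb") as f:      # hypothetical file name
    header_size = struct.unpack("<Q", f.read(8))[0]  # u64, little-endian
    header = json.loads(f.read(header_size))

metadata = header.get("__metadata__", {})
workflow = metadata.get("workflow")          # a JSON string when present
if workflow:
    print(sorted(json.loads(workflow)))      # top-level keys of the workflow
```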

@@ -465,7 +465,7 @@ export class ComfyUI {
 		const fileInput = $el("input", {
 			id: "comfy-file-input",
 			type: "file",
-			accept: ".json,image/png",
+			accept: ".json,image/png,.latent",
 			style: { display: "none" },
 			parent: document.body,
 			onchange: () => {

@@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
 			var v = valueControl.value;

-			let min = targetWidget.options.min;
-			let max = targetWidget.options.max;
-			// limit to something that javascript can handle
-			max = Math.min(1125899906842624, max);
-			min = Math.max(-1125899906842624, min);
-			let range = (max - min) / (targetWidget.options.step / 10);
-
-			//adjust values based on valueControl Behaviour
-			switch (v) {
-				case "fixed":
-					break;
-				case "increment":
-					targetWidget.value += targetWidget.options.step / 10;
-					break;
-				case "decrement":
-					targetWidget.value -= targetWidget.options.step / 10;
-					break;
-				case "randomize":
-					targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
-				default:
-					break;
-			}
-			/*check if values are over or under their respective
-			* ranges and set them to min or max.*/
-			if (targetWidget.value < min)
-				targetWidget.value = min;
-			if (targetWidget.value > max)
-				targetWidget.value = max;
+			if (targetWidget.type == "combo" && v !== "fixed") {
+				let current_index = targetWidget.options.values.indexOf(targetWidget.value);
+				let current_length = targetWidget.options.values.length;
+
+				switch (v) {
+					case "increment":
+						current_index += 1;
+						break;
+					case "decrement":
+						current_index -= 1;
+						break;
+					case "randomize":
+						current_index = Math.floor(Math.random() * current_length);
+					default:
+						break;
+				}
+				current_index = Math.max(0, current_index);
+				current_index = Math.min(current_length - 1, current_index);
+				if (current_index >= 0) {
+					let value = targetWidget.options.values[current_index];
+					targetWidget.value = value;
+					targetWidget.callback(value);
+				}
+			} else { //number
+				let min = targetWidget.options.min;
+				let max = targetWidget.options.max;
+				// limit to something that javascript can handle
+				max = Math.min(1125899906842624, max);
+				min = Math.max(-1125899906842624, min);
+				let range = (max - min) / (targetWidget.options.step / 10);
+
+				//adjust values based on valueControl Behaviour
+				switch (v) {
+					case "fixed":
+						break;
+					case "increment":
+						targetWidget.value += targetWidget.options.step / 10;
+						break;
+					case "decrement":
+						targetWidget.value -= targetWidget.options.step / 10;
+						break;
+					case "randomize":
+						targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
+					default:
+						break;
+				}
+				/*check if values are over or under their respective
+				 * ranges and set them to min or max.*/
+				if (targetWidget.value < min)
+					targetWidget.value = min;
+				if (targetWidget.value > max)
+					targetWidget.value = max;
+			}
 		}
 	return valueControl;
 };
@@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
 				computeSize(node.size);
 			}
 			const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
-			const t = ctx.getTransform();
 			const margin = 10;
+			const elRect = ctx.canvas.getBoundingClientRect();
+			const transform = new DOMMatrix()
+				.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
+				.multiplySelf(ctx.getTransform())
+				.translateSelf(margin, margin + y);
+
 			Object.assign(this.inputEl.style, {
-				left: `${t.a * margin + t.e}px`,
-				top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
-				width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
-				background: (!node.color)?'':node.color,
-				height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
+				transformOrigin: "0 0",
+				transform: transform,
+				left: "0px",
+				top: "0px",
+				width: `${widgetWidth - (margin * 2)}px`,
+				height: `${this.parent.inputHeight - (margin * 2)}px`,
 				position: "absolute",
+				background: (!node.color)?'':node.color,
 				color: (!node.color)?'':'white',
 				zIndex: app.graph._nodes.indexOf(node),
-				fontSize: `${t.d * 10.0}px`,
 			});
 			this.inputEl.hidden = !visible;
 		},

@@ -39,6 +39,8 @@ body {
 	padding: 2px;
 	resize: none;
 	border: none;
+	box-sizing: border-box;
+	font-size: 10px;
 }

 .comfy-modal {