Merge remote-tracking branch 'upstream/master' into addBatchIndex

This commit is contained in:
flyingshutter 2023-05-27 16:06:57 +02:00
commit abc3d0baf2
26 changed files with 1142 additions and 421 deletions

View File

@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x

View File

@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
return x+h
def slice_attention(q, k, v):
r1 = torch.zeros_like(k, device=q.device)
scale = (int(q.shape[-1])**(-0.5))
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
return r1
class AttnBlock(nn.Module):
def __init__(self, in_channels):
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
# compute attention
b,c,h,w = q.shape
scale = (int(c)**(-0.5))
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = torch.zeros_like(k, device=q.device)
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_)
return x+h_
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out

View File

@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
"""
B, N, _ = metric.shape
if r <= 0:
if r <= 0 or w == 1 or h == 1:
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

View File

@ -127,6 +127,32 @@ if args.cpu:
print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
return "{}".format(device.type)
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Using device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
current_loaded_model = None
current_gpu_controlnets = []
@ -233,22 +259,6 @@ def unload_if_low_vram(model):
return model.cpu()
return model
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type

View File

@ -2,17 +2,26 @@ import torch
import comfy.model_management
import comfy.samplers
import math
import numpy as np
def prepare_noise(latent_image, seed, skip=0):
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
for _ in range(skip):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""

View File

@ -6,6 +6,10 @@ import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
#The main sampling function shared by all the samplers
#Returns predicted noise
@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
if len(c_adm) > 0:
@ -471,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
@ -508,6 +532,8 @@ class KSampler:
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":

View File

@ -581,12 +581,9 @@ class VAE:
samples = samples.cpu()
return samples
def resize_image_to(tensor, target_latent_tensor, batched_number):
tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center")
target_batch_size = target_latent_tensor.shape[0]
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
print(current_batch_size, target_batch_size)
#print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@ -623,7 +620,9 @@ class ControlNet:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_model.dtype == torch.float16:
precision_scope = torch.autocast
@ -794,10 +793,14 @@ class T2IAdapter:
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.control_input = None
self.cond_hint = None
self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint)
self.t2i_model.cpu()

View File

@ -72,7 +72,7 @@ class MaskToImage:
FUNCTION = "mask_to_image"
def mask_to_image(self, mask):
result = mask[None, :, :, None].expand(-1, -1, -1, 3)
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return (result,)
class ImageToMask:

View File

@ -59,6 +59,12 @@ class Blend:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
class Blur:
def __init__(self):
pass
@ -88,12 +94,6 @@ class Blur:
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
@ -101,10 +101,11 @@ class Blur:
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
@ -167,9 +168,15 @@ class Sharpen:
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.1
}),
@ -181,21 +188,21 @@ class Sharpen:
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)

View File

@ -0,0 +1,108 @@
import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
@staticmethod
def get_batch(latents, list_ind, offset):
'''prepare a batch out of the list of latents'''
samples = latents[list_ind]['samples']
shape = samples.shape
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
if mask.shape[0] < samples.shape[0]:
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
if 'batch_index' in latents[list_ind]:
batch_inds = latents[list_ind]['batch_index']
else:
batch_inds = [x+offset for x in range(shape[0])]
return samples, mask, batch_inds
@staticmethod
def get_slices(indexable, num, batch_size):
'''divides an indexable object into num slices of length batch_size, and a remainder'''
slices = []
for i in range(num):
slices.append(indexable[i*batch_size:(i+1)*batch_size])
if num * batch_size < len(indexable):
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
return list(zip(*result))
@staticmethod
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
current_batch = remainder
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
for s in output_list:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}

View File

@ -17,7 +17,7 @@ class UpscaleModelLoader:
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
out = model_loading.load_state_dict(sd).eval()
return (out, )

View File

@ -6,6 +6,7 @@ import threading
import heapq
import traceback
import gc
import time
import torch
import nodes
@ -26,21 +27,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
input_data_all[x] = [unique_id]
return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if intput_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@ -55,21 +117,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
nodes.before_node_execution()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
if "ui" in outputs[unique_id]:
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"]
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item):
@ -105,7 +166,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
is_changed = class_def.IS_CHANGED(**input_data_all)
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
@ -144,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
def execute(self, prompt, extra_data={}):
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@ -155,6 +218,10 @@ class PromptExecutor:
else:
self.server.client_id = None
execution_start_time = time.perf_counter()
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
@ -169,32 +236,34 @@ class PromptExecutor:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
executed = set()
try:
to_execute = []
for x in prompt:
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
to_execute += [(0, x)]
for x in list(execute_outputs):
to_execute += [(0, x)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
if hasattr(class_, 'OUTPUT_NODE'):
if class_.OUTPUT_NODE == True:
valid = False
try:
m = validate_inputs(prompt, x)
valid = m[0]
except:
valid = False
if valid:
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui)
except Exception as e:
print(traceback.format_exc())
if isinstance(e, comfy.model_management.InterruptProcessingException):
print("Processing interrupted")
else:
message = str(traceback.format_exc())
print(message)
if self.server.client_id is not None:
self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id)
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
@ -210,14 +279,18 @@ class PromptExecutor:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id)
self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
comfy.model_management.soft_empty_cache()
def validate_inputs(prompt, item):
def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
return validated[unique_id]
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -238,8 +311,9 @@ def validate_inputs(prompt, item):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
if r[val[1]] != type_input:
return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input))
r = validate_inputs(prompt, o_id)
r = validate_inputs(prompt, o_id, validated)
if r[0] == False:
validated[o_id] = r
return r
else:
if type_input == "INT":
@ -254,20 +328,25 @@ def validate_inputs(prompt, item):
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
return (False, "Value smaller than min. {}, {}".format(class_type, x))
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x))
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for r in ret:
if r != True:
return (False, "{}, {}".format(class_type, r))
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "")
ret = (True, "")
validated[unique_id] = ret
return ret
def validate_prompt(prompt):
outputs = set()
@ -281,11 +360,12 @@ def validate_prompt(prompt):
good_outputs = set()
errors = []
validated = {}
for o in outputs:
valid = False
reason = ""
try:
m = validate_inputs(prompt, o)
m = validate_inputs(prompt, o, validated)
valid = m[0]
reason = m[1]
except Exception as e:
@ -294,7 +374,7 @@ def validate_prompt(prompt):
reason = "Parsing error"
if valid == True:
good_outputs.add(x)
good_outputs.add(o)
else:
print("Failed to validate prompt for output {} {}".format(o, reason))
print("output will be ignored")
@ -304,7 +384,7 @@ def validate_prompt(prompt):
errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors)))
return (False, "Prompt has no properly connected outputs\n {}".format(errors_list))
return (True, "")
return (True, "", list(good_outputs))
class PromptQueue:
@ -340,8 +420,7 @@ class PromptQueue:
prompt = self.currently_running.pop(item_id)
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
if "ui" in outputs[o]:
self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"]
self.history[prompt[1]]["outputs"][o] = outputs[o]
self.server.queue_updated()
def get_current_queue(self):

View File

@ -147,4 +147,37 @@ def get_filename_list(folder_name):
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
return sorted(list(output_list))
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input, image_width, image_height):
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix

View File

@ -33,8 +33,8 @@ def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs)
e.execute(item[2], item[1], item[3], item[4])
q.task_done(item_id, e.outputs_ui)
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

225
nodes.py
View File

@ -6,10 +6,12 @@ import json
import hashlib
import traceback
import math
import time
from PIL import Image
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -28,6 +30,7 @@ import importlib
import folder_paths
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@ -145,9 +148,6 @@ class ConditioningSetMask:
return (c, )
class VAEDecode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -160,9 +160,6 @@ class VAEDecode:
return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
@ -175,9 +172,6 @@ class VAEDecodeTiled:
return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -202,9 +196,6 @@ class VAEEncode:
return ({"samples":t}, )
class VAEEncodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
@ -219,9 +210,6 @@ class VAEEncodeTiled:
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
@ -260,6 +248,81 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# support save metadata for latent sharing
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
file = os.path.join(full_output_folder, file)
output = {}
output["latent_tensor"] = samples["samples"]
safetensors.torch.save_file(output, file, metadata=metadata)
return {}
class LoadLatent:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"].float()}
return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, latent):
if not folder_paths.exists_annotated_filepath(latent):
return "Invalid latent file: {}".format(latent)
return True
class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
@ -296,7 +359,10 @@ class DiffusersLoader:
paths = []
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths += next(os.walk(search_path))[1]
for root, subdir, files in os.walk(search_path, followlinks=True):
if "model_index.json" in files:
paths.append(os.path.relpath(root, start=search_path))
return {"required": {"model_path": (paths,), }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -306,9 +372,9 @@ class DiffusersLoader:
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
paths = next(os.walk(search_path))[1]
if model_path in paths:
model_path = os.path.join(search_path, model_path)
path = os.path.join(search_path, model_path)
if os.path.exists(path):
model_path = path
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
@ -629,18 +695,57 @@ class LatentFromBatch:
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "rotate"
FUNCTION = "frombatch"
CATEGORY = "latent"
CATEGORY = "latent/batch"
def rotate(self, samples, batch_index):
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone()
s["batch_index"] = batch_index
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,)
class LatentUpscale:
@ -795,7 +900,7 @@ class SetLatentNoiseMask:
def set_mask(self, samples, mask):
s = samples.copy()
s["noise_mask"] = mask
s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
@ -805,8 +910,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
skip = latent["batch_index"] if "batch_index" in latent else 0
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@ -901,39 +1006,7 @@ class SaveImage:
CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input):
input = input.replace("%width%", str(images[0].shape[1]))
input = input.replace("%height%", str(images[0].shape[0]))
return input
filename_prefix = compute_vars(filename_prefix)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(self.output_dir, subfolder)
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -984,6 +1057,7 @@ class LoadImage:
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
@ -1027,6 +1101,7 @@ class LoadImageMask:
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA")
mask = None
@ -1170,6 +1245,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
@ -1206,6 +1282,9 @@ NODE_CLASS_MAPPINGS = {
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1244,6 +1323,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent",
"LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
@ -1275,14 +1356,18 @@ def load_custom_node(module_path):
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
return False
def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
@ -1291,11 +1376,25 @@ def load_custom_nodes():
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
if module_path.endswith(".disabled"): continue
time_before = time.perf_counter()
success = load_custom_node(module_path)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_nodes()

View File

@ -175,6 +175,8 @@
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
@ -183,7 +185,9 @@
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
"\n",
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
" for line in p.stdout:\n",
" print(line.decode(), end='')\n",

View File

@ -81,7 +81,7 @@ class PromptServer():
# Reusing existing session, remove old
self.sockets.pop(sid, None)
else:
sid = uuid.uuid4().hex
sid = uuid.uuid4().hex
self.sockets[sid] = ws
@ -115,21 +115,23 @@ class PromptServer():
def get_dir_by_type(dir_type):
if dir_type is None:
type_dir = folder_paths.get_input_directory()
elif dir_type == "input":
dir_type = "input"
if dir_type == "input":
type_dir = folder_paths.get_input_directory()
elif dir_type == "temp":
type_dir = folder_paths.get_temp_directory()
elif dir_type == "output":
type_dir = folder_paths.get_output_directory()
return type_dir
return type_dir, dir_type
def image_upload(post, image_save_function=None):
image = post.get("image")
overwrite = post.get("overwrite")
image_upload_type = post.get("type")
upload_dir = get_dir_by_type(image_upload_type)
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
if image and image.file:
filename = image.filename
@ -148,10 +150,14 @@ class PromptServer():
split = os.path.splitext(filename)
filepath = os.path.join(full_output_folder, filename)
i = 1
while os.path.exists(filepath):
filename = f"{split[0]} ({i}){split[1]}"
i += 1
if overwrite is not None and (overwrite == "true" or overwrite == "1"):
pass
else:
i = 1
while os.path.exists(filepath):
filename = f"{split[0]} ({i}){split[1]}"
filepath = os.path.join(full_output_folder, filename)
i += 1
if image_save_function is not None:
image_save_function(image, post, filepath)
@ -255,22 +261,34 @@ class PromptServer():
async def get_prompt(request):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
out[x] = node_info(x)
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/history")
@ -312,14 +330,16 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
self.prompt_queue.put((number, id(prompt), prompt, extra_data))
prompt_id = str(uuid.uuid4())
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
return web.json_response({"prompt_id": prompt_id})
else:
resp_code = 400
out_string = valid[1]
print("invalid prompt:", valid[1])
return web.json_response({"error": valid[1]}, status=400)
else:
return web.json_response({"error": "no prompt"}, status=400)
return web.Response(body=out_string, status=resp_code)
@routes.post("/queue")
async def post_queue(request):
json_data = await request.json()
@ -329,9 +349,9 @@ class PromptServer():
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
delete_func = lambda a: a[1] == int(id_to_delete)
delete_func = lambda a: a[1] == id_to_delete
self.prompt_queue.delete_queue_item(delete_func)
return web.Response(status=200)
@routes.post("/interrupt")
@ -355,7 +375,7 @@ class PromptServer():
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root),
web.static('/', self.web_root, follow_symlinks=True),
])
def get_queue_info(self):

View File

@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) {
class MaskEditorDialog extends ComfyDialog {
static instance = null;
static getInstance() {
if(!MaskEditorDialog.instance) {
MaskEditorDialog.instance = new MaskEditorDialog(app);
}
return MaskEditorDialog.instance;
}
is_layout_created = false;
constructor() {
super();
this.element = $el("div.comfy-modal", { parent: document.body },
[ $el("div.comfy-modal-content",
[...this.createButtons()]),
]);
MaskEditorDialog.instance = this;
}
createButtons() {
return [];
}
clearMask(self) {
}
createButton(name, callback) {
var button = document.createElement("button");
button.innerText = name;
button.addEventListener("click", callback);
return button;
}
createLeftButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "left";
button.style.marginRight = "4px";
return button;
}
createRightButton(name, callback) {
var button = this.createButton(name, callback);
button.style.cssFloat = "right";
button.style.marginLeft = "4px";
return button;
}
createLeftSlider(self, name, callback) {
const divElement = document.createElement('div');
divElement.id = "maskeditor-slider";
@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog {
brush.style.MozBorderRadius = "50%";
brush.style.WebkitBorderRadius = "50%";
brush.style.position = "absolute";
brush.style.zIndex = 100;
brush.style.zIndex = 8889;
brush.style.pointerEvents = "none";
this.brush = brush;
this.element.appendChild(imgCanvas);
@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog {
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.close();
});
var saveButton = this.createRightButton("Save", () => {
this.saveButton = this.createRightButton("Save", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
self.save();
@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog {
this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(saveButton);
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(brush_size_slider);
this.element.style.display = "block";
imgCanvas.style.position = "relative";
imgCanvas.style.top = "200";
imgCanvas.style.left = "0";
@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog {
}
show() {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
if(!this.is_layout_created) {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas);
this.setlayout(imgCanvas, maskCanvas);
// prepare content
this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas;
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
// prepare content
this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas;
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
this.setImages(imgCanvas, backupCanvas);
this.setEventHandler(maskCanvas);
this.setEventHandler(maskCanvas);
this.is_layout_created = true;
// replacement of onClose hook since close is not real close
const self = this;
const observer = new MutationObserver(function(mutations) {
mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
ComfyApp.onClipspaceEditorClosed();
}
self.last_display_style = self.element.style.display;
}
});
});
const config = { attributes: true };
observer.observe(this.element, config);
}
this.setImages(this.imgCanvas, this.backupCanvas);
if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node";
}
else {
this.saveButton.innerText = "Save";
}
this.saveButton.disabled = false;
this.element.style.display = "block";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
}
isOpened() {
return this.element.style.display == "block";
}
setImages(imgCanvas, backupCanvas) {
@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog {
const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load
const orig_image = new Image();
window.addEventListener("resize", () => {
@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog {
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;
}g
}
setEventHandler(maskCanvas) {
maskCanvas.addEventListener("contextmenu", (event) => {
@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog {
self.brush_size = Math.min(self.brush_size+2, 100);
} else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1);
} else if(event.key === 'Enter') {
self.save();
}
self.updateBrushPreview(self);
@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog {
}
}
save() {
async save() {
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog {
formData.append('type', "input");
formData.append('subfolder', "clipspace");
uploadMask(item, formData);
this.saveButton.innerText = "Saving...";
this.saveButton.disabled = true;
await uploadMask(item, formData);
ComfyApp.onClipspaceEditorSave();
this.close();
}
}
@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog {
app.registerExtension({
name: "Comfy.MaskEditor",
init(app) {
const callback =
ComfyApp.open_maskeditor =
function () {
let dlg = new MaskEditorDialog(app);
dlg.show();
const dlg = MaskEditorDialog.getInstance();
if(!dlg.isOpened()) {
dlg.show();
}
};
const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback);
ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor);
}
});

View File

@ -300,7 +300,7 @@ app.registerExtension({
}
}
if (widget.type === "number") {
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
}

View File

@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action)
//when clicked on top of a node
//and it is not interactive
if (node && this.allow_interaction && !skip_action && !this.read_only) {
if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) {
if (!this.live_mode && !node.flags.pinned) {
this.bringToFront(node);
} //if it wasn't selected?
//not dragging mouse to connect two slots
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
//Search for corner for resize
if ( !skip_action &&
node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action)
}
//double clicking
if (is_double_click && this.selected_nodes[node.id]) {
if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) {
//double click node
if (node.onDblClick) {
node.onDblClick( e, pos, this );
@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action)
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
if (this.dragging_rectangle)
{
this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0];
@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action)
this.ds.offset[1] += delta[1] / this.ds.scale;
this.dirty_canvas = true;
this.dirty_bgcanvas = true;
} else if (this.allow_interaction && !this.read_only) {
} else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) {
if (this.connecting_node) {
this.dirty_canvas = true;
}
//get node over
var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes);
//remove mouseover flag
for (var i = 0, l = this.graph._nodes.length; i < l; ++i) {
if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) {
@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
if (show_text) {
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
}
break;
case "toggle":
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
if (show_text) {
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = w.value ? text_color : secondary_text_color;
ctx.textAlign = "right";
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(
w.name + " " + Number(w.value).toFixed(3),
w.label || w.name + " " + Number(w.value).toFixed(3),
widget_width * 0.5,
y + H * 0.7
);
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
}
ctx.fillStyle = secondary_text_color;
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillStyle = text_color;
ctx.textAlign = "right";
if (w.type == "number") {
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
//ctx.stroke();
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = text_color;
ctx.textAlign = "right";
@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action)
event,
active_widget
) {
if (!node.widgets || !node.widgets.length) {
if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) {
return null;
}
@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action)
canvas.graph.add(group);
};
/**
* Determines the furthest nodes in each direction
* @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.getBoundaryNodes = function(nodes) {
let top = null;
let right = null;
let bottom = null;
let left = null;
for (const nID in nodes) {
const node = nodes[nID];
const [x, y] = node.pos;
const [width, height] = node.size;
if (top === null || y < top.pos[1]) {
top = node;
}
if (right === null || x + width > right.pos[0] + right.size[0]) {
right = node;
}
if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) {
bottom = node;
}
if (left === null || x < left.pos[0]) {
left = node;
}
}
return {
"top": top,
"right": right,
"bottom": bottom,
"left": left
};
}
/**
* Determines the furthest nodes in each direction for the currently selected nodes
* @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}}
*/
LGraphCanvas.prototype.boundaryNodesForSelection = function() {
return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes));
}
/**
*
* @param {LGraphNode[]} nodes a list of nodes
* @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes
* @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction)
*/
LGraphCanvas.alignNodes = function (nodes, direction, align_to) {
if (!nodes) {
return;
}
const canvas = LGraphCanvas.active_canvas;
let boundaryNodes = []
if (align_to === undefined) {
boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes)
} else {
boundaryNodes = {
"top": align_to,
"right": align_to,
"bottom": align_to,
"left": align_to
}
}
for (const [_, node] of Object.entries(canvas.selected_nodes)) {
switch (direction) {
case "right":
node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0];
break;
case "left":
node.pos[0] = boundaryNodes["left"].pos[0];
break;
case "top":
node.pos[1] = boundaryNodes["top"].pos[1];
break;
case "bottom":
node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1];
break;
}
}
canvas.dirty_canvas = true;
canvas.dirty_bgcanvas = true;
};
LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node);
}
}
LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) {
new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], {
event: event,
callback: inner_clicked,
parentMenu: prev_menu,
});
function inner_clicked(value) {
LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase());
}
}
LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) {
var canvas = LGraphCanvas.active_canvas;
@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action)
options.push({ content: "Options", callback: that.showShowGraphOptionsPanel });
}*/
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align",
has_submenu: true,
callback: LGraphCanvas.onGroupAlign,
})
}
if (this._graph_stack && this._graph_stack.length > 0) {
options.push(null, {
content: "Close subgraph",
@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action)
callback: LGraphCanvas.onMenuNodeToSubgraph
});
if (Object.keys(this.selected_nodes).length > 1) {
options.push({
content: "Align Selected To",
has_submenu: true,
callback: LGraphCanvas.onNodeAlign,
})
}
options.push(null, {
content: "Remove",
disabled: !(node.removable !== false && !node.block_delete ),

View File

@ -163,7 +163,7 @@ class ComfyApi extends EventTarget {
if (res.status !== 200) {
throw {
response: await res.text(),
response: await res.json(),
};
}
}

View File

@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@ -26,6 +26,8 @@ export class ComfyApp {
*/
static clipspace = null;
static clipspace_invalidate_handler = null;
static open_maskeditor = null;
static clipspace_return_node = null;
constructor() {
this.ui = new ComfyUI(this);
@ -49,6 +51,114 @@ export class ComfyApp {
this.shiftDown = false;
}
static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
}
static onClipspaceEditorSave() {
if(ComfyApp.clipspace_return_node) {
ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node);
}
}
static onClipspaceEditorClosed() {
ComfyApp.clipspace_return_node = null;
}
static copyToClipspace(node) {
var widgets = null;
if(node.widgets) {
widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
var imgs = undefined;
var orig_imgs = undefined;
if(node.imgs != undefined) {
imgs = [];
orig_imgs = [];
for (let i = 0; i < node.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = node.imgs[i].src;
orig_imgs[i] = imgs[i];
}
}
var selectedIndex = 0;
if(node.imageIndex) {
selectedIndex = node.imageIndex;
}
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': orig_imgs,
'images': node.images,
'selectedIndex': selectedIndex,
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
};
ComfyApp.clipspace_return_node = null;
if(ComfyApp.clipspace_invalidate_handler) {
ComfyApp.clipspace_invalidate_handler();
}
}
static pasteFromClipspace(node) {
if(ComfyApp.clipspace) {
// image paste
if(ComfyApp.clipspace.imgs && node.imgs) {
if(node.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else
app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images;
}
if(ComfyApp.clipspace.imgs) {
// deep-copy to cut link with clipspace
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
const img = new Image();
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
node.imgs = [img];
node.imageIndex = 0;
}
else {
const imgs = [];
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
node.imgs = imgs;
}
}
}
}
if(node.widgets) {
if(ComfyApp.clipspace.images) {
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
const index = node.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0) {
node.widgets[index].value = clip_image;
}
}
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
});
}
}
app.graph.setDirtyCanvas(true);
}
}
/**
* Invoke an extension callback
* @param {keyof ComfyExtension} method The extension callback to execute
@ -138,102 +248,30 @@ export class ComfyApp {
}
}
options.push(
{
content: "Copy (Clipspace)",
callback: (obj) => {
var widgets = null;
if(this.widgets) {
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
var imgs = undefined;
var orig_imgs = undefined;
if(this.imgs != undefined) {
imgs = [];
orig_imgs = [];
// prevent conflict of clipspace content
if(!ComfyApp.clipspace_return_node) {
options.push({
content: "Copy (Clipspace)",
callback: (obj) => { ComfyApp.copyToClipspace(this); }
});
for (let i = 0; i < this.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = this.imgs[i].src;
orig_imgs[i] = imgs[i];
if(ComfyApp.clipspace != null) {
options.push({
content: "Paste (Clipspace)",
callback: () => { ComfyApp.pasteFromClipspace(this); }
});
}
if(ComfyApp.isImageNode(this)) {
options.push({
content: "Open in MaskEditor",
callback: (obj) => {
ComfyApp.copyToClipspace(this);
ComfyApp.clipspace_return_node = this;
ComfyApp.open_maskeditor();
}
}
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': orig_imgs,
'images': this.images,
'selectedIndex': 0,
'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action
};
if(ComfyApp.clipspace_invalidate_handler) {
ComfyApp.clipspace_invalidate_handler();
}
}
});
if(ComfyApp.clipspace != null) {
options.push(
{
content: "Paste (Clipspace)",
callback: () => {
if(ComfyApp.clipspace) {
// image paste
if(ComfyApp.clipspace.imgs && this.imgs) {
if(this.images && ComfyApp.clipspace.images) {
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
}
else
app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images;
}
if(ComfyApp.clipspace.imgs) {
// deep-copy to cut link with clipspace
if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
const img = new Image();
img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
this.imgs = [img];
}
else {
const imgs = [];
for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
imgs[i] = new Image();
imgs[i].src = ComfyApp.clipspace.imgs[i].src;
this.imgs = imgs;
}
}
}
}
if(this.widgets) {
if(ComfyApp.clipspace.images) {
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
const index = this.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0) {
this.widgets[index].value = clip_image;
}
}
if(ComfyApp.clipspace.widgets) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
if (prop && prop.type != 'button') {
prop.value = value;
prop.callback(value);
}
});
}
}
}
app.graph.setDirtyCanvas(true);
}
}
);
});
}
}
};
}
@ -864,7 +902,9 @@ export class ComfyApp {
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
const mainCanvas = document.createElement("canvas")
mainCanvas.style.touchAction = "none"
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
@ -976,7 +1016,8 @@ export class ComfyApp {
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
const outputName = nodeData["output_name"][o] || output;
this.addOutput(outputName, output);
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();
@ -1237,7 +1278,7 @@ export class ComfyApp {
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
this.ui.dialog.show(error.response.error || error.toString());
break;
}
@ -1283,6 +1324,11 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
}
}
}

View File

@ -47,6 +47,22 @@ export function getPngMetadata(file) {
});
}
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const safetensorsData = new Uint8Array(event.target.result);
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};
reader.readAsArrayBuffer(file);
});
}
export async function importA1111(graph, parameters) {
const p = parameters.lastIndexOf("\nSteps:");
if (p > -1) {

View File

@ -465,7 +465,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png",
accept: ".json,image/png,.latent",
style: { display: "none" },
parent: document.body,
onchange: () => {

View File

@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
var v = valueControl.value;
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
switch (v) {
case "increment":
current_index += 1;
break;
case "decrement":
current_index -= 1;
break;
case "randomize":
current_index = Math.floor(Math.random() * current_length);
default:
break;
}
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
return valueControl;
};
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const t = ctx.getTransform();
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
Object.assign(this.inputEl.style, {
left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
background: (!node.color)?'':node.color,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
transformOrigin: "0 0",
transform: transform,
left: "0px",
top: "0px",
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
fontSize: `${t.d * 10.0}px`,
});
this.inputEl.hidden = !visible;
},

View File

@ -39,6 +39,8 @@ body {
padding: 2px;
resize: none;
border: none;
box-sizing: border-box;
font-size: 10px;
}
.comfy-modal {