diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..26930428f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + h = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5e4d2b60f..91e7d60ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -146,6 +146,41 @@ class ResnetBlock(nn.Module): return x+h +def slice_attention(q, k, v): + r1 = torch.zeros_like(k, device=q.device) + scale = (int(q.shape[-1])**(-0.5)) + + mem_free_total = model_management.get_free_memory(q.device) + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = torch.bmm(q[:, i:end], k) * scale + + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) + del s1 + + r1[:, :, i:end] = torch.bmm(v, s2) + del s2 + break + except model_management.OOM_EXCEPTION as e: + steps *= 2 + if steps > 128: + raise e + print("out of memory error, increasing steps and trying again", steps) + + return r1 class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -183,48 +218,15 @@ class AttnBlock(nn.Module): # compute attention b,c,h,w = q.shape - scale = (int(c)**(-0.5)) q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw v = v.reshape(b,c,h*w) - r1 = torch.zeros_like(k, device=q.device) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = torch.bmm(q[:, i:end], k) * scale - - s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) - del s1 - - r1[:, :, i:end] = torch.bmm(v, s2) - del s2 - break - except model_management.OOM_EXCEPTION as e: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - + r1 = slice_attention(q, k, v) h_ = r1.reshape(b,c,h,w) del r1 - h_ = self.proj_out(h_) return x+h_ @@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..c15323219 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -127,6 +127,32 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() + +def get_torch_device_name(device): + if hasattr(device, 'type'): + return "{}".format(device.type) + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Using device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + current_loaded_model = None current_gpu_controlnets = [] @@ -233,22 +259,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy/samplers.py b/comfy/samplers.py index 6417f2ed4..1fb928f8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: @@ -471,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -508,6 +532,8 @@ class KSampler: if self.scheduler == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": diff --git a/comfy/sd.py b/comfy/sd.py index 3543bdb77..c6be900ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -581,12 +581,9 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor @@ -623,7 +620,9 @@ class ControlNet: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -794,10 +793,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..9916f3b21 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..3be141dfe 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,6 +59,12 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) +def gaussian_kernel(kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + class Blur: def __init__(self): pass @@ -88,12 +94,6 @@ class Blur: CATEGORY = "image/postprocessing" - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) @@ -101,10 +101,11 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) @@ -167,9 +168,15 @@ class Sharpen: "max": 31, "step": 1 }), - "alpha": ("FLOAT", { + "sigma": ("FLOAT", { "default": 1.0, "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, "max": 5.0, "step": 0.1 }), @@ -181,21 +188,21 @@ class Sharpen: CATEGORY = "image/postprocessing" - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): if sharpen_radius == 0: return (image,) batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) diff --git a/execution.py b/execution.py index c19c10bc6..35f044346 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -26,21 +27,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + intput_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + intput_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if intput_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -55,21 +117,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -105,7 +166,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -144,10 +206,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -155,6 +218,10 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -169,32 +236,34 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) except Exception as e: - print(traceback.format_exc()) + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): @@ -210,14 +279,18 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -238,8 +311,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -254,20 +328,25 @@ def validate_inputs(prompt, item): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -281,11 +360,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -294,7 +374,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -304,7 +384,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: @@ -340,8 +420,7 @@ class PromptQueue: prompt = self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..28f117824 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -147,4 +147,37 @@ def get_filename_list(folder_name): output_list.update(filter_files_extensions(recursive_search(x), folders[1])) return sorted(list(output_list)) +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/main.py b/main.py index eb97a2fb8..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -33,8 +33,8 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) - q.task_done(item_id, e.outputs) + e.execute(item[2], item[1], item[3], item[4]) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) diff --git a/nodes.py b/nodes.py index 760db24e1..bae330bc9 100644 --- a/nodes.py +++ b/nodes.py @@ -6,10 +6,12 @@ import json import hashlib import traceback import math +import time -from PIL import Image +from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -28,6 +30,7 @@ import importlib import folder_paths + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -145,9 +148,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -160,9 +160,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -175,9 +172,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -202,9 +196,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -219,9 +210,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} @@ -260,6 +248,81 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"prompt": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) + + return {} + + +class LoadLatent: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") + samples = {"samples": latent["latent_tensor"].float()} + return (samples, ) + + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -296,7 +359,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -306,9 +372,9 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) @@ -629,18 +695,57 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -795,7 +900,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): @@ -805,8 +910,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -901,39 +1006,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() @@ -984,6 +1057,7 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1027,6 +1101,7 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") mask = None @@ -1170,6 +1245,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1206,6 +1282,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1244,6 +1323,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1275,14 +1356,18 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1291,11 +1376,25 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - load_custom_node(module_path) + if module_path.endswith(".disabled"): continue + time_before = time.perf_counter() + success = load_custom_node(module_path) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) + + if len(node_import_times) > 0: + print("\nImport times for custom nodes:") + for n in sorted(node_import_times): + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) + print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", diff --git a/server.py b/server.py index c1226f304..701c0e7a7 100644 --- a/server.py +++ b/server.py @@ -81,7 +81,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -115,21 +115,23 @@ class PromptServer(): def get_dir_by_type(dir_type): if dir_type is None: - type_dir = folder_paths.get_input_directory() - elif dir_type == "input": + dir_type = "input" + + if dir_type == "input": type_dir = folder_paths.get_input_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": type_dir = folder_paths.get_output_directory() - return type_dir + return type_dir, dir_type def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") image_upload_type = post.get("type") - upload_dir = get_dir_by_type(image_upload_type) + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename @@ -148,10 +150,14 @@ class PromptServer(): split = os.path.splitext(filename) filepath = os.path.join(full_output_folder, filename) - i = 1 - while os.path.exists(filepath): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 if image_save_function is not None: image_save_function(image, post, filepath) @@ -255,22 +261,34 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) + def node_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = x - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info + out[x] = node_info(x) + return web.json_response(out) + + @routes.get("/object_info/{node_class}") + async def get_object_info_node(request): + node_class = request.match_info.get("node_class", None) + out = {} + if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + out[node_class] = node_info(node_class) return web.json_response(out) @routes.get("/history") @@ -312,14 +330,16 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + prompt_id = str(uuid.uuid4()) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + return web.json_response({"prompt_id": prompt_id}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1]}, status=400) + else: + return web.json_response({"error": "no prompt"}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() @@ -329,9 +349,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") @@ -355,7 +375,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 552059e86..4b0c12747 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) { class MaskEditorDialog extends ComfyDialog { static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + constructor() { super(); this.element = $el("div.comfy-modal", { parent: document.body }, [ $el("div.comfy-modal-content", [...this.createButtons()]), ]); - MaskEditorDialog.instance = this; } createButtons() { return []; } - clearMask(self) { - } - createButton(name, callback) { var button = document.createElement("button"); button.innerText = name; button.addEventListener("click", callback); return button; } + createLeftButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "left"; button.style.marginRight = "4px"; return button; } + createRightButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "right"; button.style.marginLeft = "4px"; return button; } + createLeftSlider(self, name, callback) { const divElement = document.createElement('div'); divElement.id = "maskeditor-slider"; @@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog { brush.style.MozBorderRadius = "50%"; brush.style.WebkitBorderRadius = "50%"; brush.style.position = "absolute"; - brush.style.zIndex = 100; + brush.style.zIndex = 8889; brush.style.pointerEvents = "none"; this.brush = brush; this.element.appendChild(imgCanvas); @@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog { document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.close(); }); - var saveButton = this.createRightButton("Save", () => { + + this.saveButton = this.createRightButton("Save", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.save(); @@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); - bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(brush_size_slider); - this.element.style.display = "block"; imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; @@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog { } show() { - // layout - const imgCanvas = document.createElement('canvas'); - const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); - imgCanvas.id = "imageCanvas"; - maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; - this.setlayout(imgCanvas, maskCanvas); + this.setlayout(imgCanvas, maskCanvas); - // prepare content - this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); - this.setImages(imgCanvas, backupCanvas); - this.setEventHandler(maskCanvas); + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. + } + + isOpened() { + return this.element.style.display == "block"; } setImages(imgCanvas, backupCanvas) { @@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog { const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + // image load const orig_image = new Image(); window.addEventListener("resize", () => { @@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog { rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; - }g - + } setEventHandler(maskCanvas) { maskCanvas.addEventListener("contextmenu", (event) => { @@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog { self.brush_size = Math.min(self.brush_size+2, 100); } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); } self.updateBrushPreview(self); @@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog { } } - save() { + async save() { const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); @@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog { formData.append('type', "input"); formData.append('subfolder', "clipspace"); - uploadMask(item, formData); + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); this.close(); } } @@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog { app.registerExtension({ name: "Comfy.MaskEditor", init(app) { - const callback = + ComfyApp.open_maskeditor = function () { - let dlg = new MaskEditorDialog(app); - dlg.show(); + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } }; const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 - ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } }); \ No newline at end of file diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; @@ -9911,7 +9913,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10302,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13015,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13137,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..4f061c358 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -163,7 +163,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index adb14c605..807301b80 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -26,6 +26,8 @@ export class ComfyApp { */ static clipspace = null; static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -49,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -138,102 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - var imgs = undefined; - var orig_imgs = undefined; - if(this.imgs != undefined) { - imgs = []; - orig_imgs = []; + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - for (let i = 0; i < this.imgs.length; i++) { - imgs[i] = new Image(); - imgs[i].src = this.imgs[i].src; - orig_imgs[i] = imgs[i]; + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } + + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': orig_imgs, - 'images': this.images, - 'selectedIndex': 0, - 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action - }; - - if(ComfyApp.clipspace_invalidate_handler) { - ComfyApp.clipspace_invalidate_handler(); - } - } - }); - - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace) { - // image paste - if(ComfyApp.clipspace.imgs && this.imgs) { - if(this.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; - - } - else - app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.imgs) { - // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - const img = new Image(); - img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; - this.imgs = [img]; - } - else { - const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - this.widgets[index].value = clip_image; - } - } - if(ComfyApp.clipspace.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } - }); - } - } - } - - app.graph.setDirtyCanvas(true); - } - } - ); + }); + } } }; } @@ -864,7 +902,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); @@ -976,7 +1016,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); @@ -1237,7 +1278,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + this.ui.dialog.show(error.response.error || error.toString()); break; } @@ -1283,6 +1324,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..8ddb7a1c5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,22 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 77517aec1..2c9043d00 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 65edc0392..82168b08b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; - //adjust values based on valueControl Behaviour - switch (v) { - case "fixed": - break; - case "increment": - targetWidget.value += targetWidget.options.step / 10; - break; - case "decrement": - targetWidget.value -= targetWidget.options.step / 10; - break; - case "randomize": - targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; - default: - break; + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let min = targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; @@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) { computeSize(node.size); } const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const t = ctx.getTransform(); const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + Object.assign(this.inputEl.style, { - left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, - width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - background: (!node.color)?'':node.color, - height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, + transformOrigin: "0 0", + transform: transform, + left: "0px", + top: "0px", + width: `${widgetWidth - (margin * 2)}px`, + height: `${this.parent.inputHeight - (margin * 2)}px`, position: "absolute", + background: (!node.color)?'':node.color, color: (!node.color)?'':'white', zIndex: app.graph._nodes.indexOf(node), - fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; }, diff --git a/web/style.css b/web/style.css index df220cc02..87f096e14 100644 --- a/web/style.css +++ b/web/style.css @@ -39,6 +39,8 @@ body { padding: 2px; resize: none; border: none; + box-sizing: border-box; + font-size: 10px; } .comfy-modal {