Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-17 08:52:34 +08:00)

Commit 7e4da3c48a: Merge branch 'master' into aitemplate
@@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
         x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
         old_denoised = denoised
     return x
+
+@torch.no_grad()
+def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
+    """DPM-Solver++(2M) SDE."""
+
+    if solver_type not in {'heun', 'midpoint'}:
+        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
+
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+
+    old_denoised = None
+    h_last = None
+    h = None
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            # DPM-Solver++(2M) SDE
+            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
+            h = s - t
+            eta_h = eta * h
+
+            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
+
+            if old_denoised is not None:
+                r = h_last / h
+                if solver_type == 'heun':
+                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
+                elif solver_type == 'midpoint':
+                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
+
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
+
+        old_denoised = denoised
+        h_last = h
+    return x
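Note: the new sampler follows the standard k-diffusion calling convention, so it can be exercised in isolation. A minimal sketch, assuming the fork keeps upstream's comfy.k_diffusion.sampling module path; the toy denoiser and schedule values below are illustrative, not part of this commit:

import torch
from comfy.k_diffusion import sampling as k_diffusion_sampling

def toy_denoiser(x, sigma, **extra_args):
    # stand-in for a wrapped diffusion model; a real one returns the denoised prediction
    return x * 0.0

x = torch.randn(1, 4, 64, 64)
sigmas = k_diffusion_sampling.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)
out = k_diffusion_sampling.sample_dpmpp_2m_sde(toy_denoiser, x * sigmas[0], sigmas, solver_type='midpoint')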
@@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):

         return x+h

+def slice_attention(q, k, v):
+    r1 = torch.zeros_like(k, device=q.device)
+    scale = (int(q.shape[-1])**(-0.5))
+
+    mem_free_total = model_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = torch.bmm(q[:, i:end], k) * scale
+
+                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
+                del s1
+
+                r1[:, :, i:end] = torch.bmm(v, s2)
+                del s2
+            break
+        except model_management.OOM_EXCEPTION as e:
+            steps *= 2
+            if steps > 128:
+                raise e
+            print("out of memory error, increasing steps and trying again", steps)
+
+    return r1

 class AttnBlock(nn.Module):
     def __init__(self, in_channels):
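Note: the shape contract of the slice_attention helper added above is easy to get backwards: q comes in as (B, HW, C), k and v as (B, C, HW), and the result has k's shape. A minimal unsliced reference of the same computation (shapes illustrative), useful for checking the slicing logic:

import torch

B, C, HW = 1, 512, 4096
q = torch.randn(B, HW, C)
k = torch.randn(B, C, HW)
v = torch.randn(B, C, HW)

scale = int(q.shape[-1]) ** (-0.5)
# identical to slice_attention's output, computed in one shot without the OOM fallback
s = torch.nn.functional.softmax(torch.bmm(q, k) * scale, dim=2)  # (B, HW, HW)
r1 = torch.bmm(v, s.permute(0, 2, 1))                            # (B, C, HW)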
@@ -183,48 +218,15 @@ class AttnBlock(nn.Module):

         # compute attention
         b,c,h,w = q.shape
-        scale = (int(c)**(-0.5))

         q = q.reshape(b,c,h*w)
         q = q.permute(0,2,1)   # b,hw,c
         k = k.reshape(b,c,h*w) # b,c,hw
         v = v.reshape(b,c,h*w)

-        r1 = torch.zeros_like(k, device=q.device)
-
-        mem_free_total = model_management.get_free_memory(q.device)
-
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
-        modifier = 3 if q.element_size() == 2 else 2.5
-        mem_required = tensor_size * modifier
-        steps = 1
-
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
-        while True:
-            try:
-                slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-                for i in range(0, q.shape[1], slice_size):
-                    end = i + slice_size
-                    s1 = torch.bmm(q[:, i:end], k) * scale
-
-                    s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
-                    del s1
-
-                    r1[:, :, i:end] = torch.bmm(v, s2)
-                    del s2
-                break
-            except model_management.OOM_EXCEPTION as e:
-                steps *= 2
-                if steps > 128:
-                    raise e
-                print("out of memory error, increasing steps and trying again", steps)
-
+        r1 = slice_attention(q, k, v)
         h_ = r1.reshape(b,c,h,w)
         del r1

         h_ = self.proj_out(h_)

         return x+h_
@@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):

         # compute attention
         B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
-
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
             (q, k, v),
         )
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-
-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+
+        try:
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+            out = out.transpose(2, 3).reshape(B, C, H, W)
+        except model_management.OOM_EXCEPTION as e:
+            print("scaled_dot_product_attention OOMed: switched to slice attention")
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

         out = self.proj_out(out)
         return x+out
@@ -555,10 +555,10 @@ class AITemplateModelWrapper:
         return noise_pred

 class KSampler:
-    SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
+    SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
-                "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
+                "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]

     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
         self.model = model
@@ -596,6 +596,8 @@ class KSampler:

         if self.scheduler == "karras":
             sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+        elif self.scheduler == "exponential":
+            sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
         elif self.scheduler == "normal":
             sigmas = self.model_wrap.get_sigmas(steps)
         elif self.scheduler == "simple":
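Note: the "exponential" scheduler defers to upstream k-diffusion's get_sigmas_exponential, which spaces sigmas evenly in log space with a trailing zero. A standalone sketch of the equivalent computation (not this file's code, just the idea):

import math
import torch

def sigmas_exponential_sketch(n, sigma_min, sigma_max):
    # evenly spaced in log(sigma), descending from sigma_max to sigma_min
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
    return torch.cat([sigmas, sigmas.new_zeros([1])])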
@@ -59,6 +59,12 @@ class Blend:
     def g(self, x):
         return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

+def gaussian_kernel(kernel_size: int, sigma: float):
+    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
+    d = torch.sqrt(x * x + y * y)
+    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+    return g / g.sum()
+
 class Blur:
     def __init__(self):
         pass
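Note: hoisting gaussian_kernel to module level lets Blur and Sharpen share one implementation. A quick sanity check, assuming the function above is in scope (values illustrative):

import torch

k = gaussian_kernel(5, 1.0)       # 5x5 kernel, sigma = 1.0
assert k.shape == (5, 5)
assert abs(float(k.sum()) - 1.0) < 1e-6  # normalized: blurring preserves brightness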
@@ -88,12 +94,6 @@ class Blur:

     CATEGORY = "image/postprocessing"

-    def gaussian_kernel(self, kernel_size: int, sigma: float):
-        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
-        d = torch.sqrt(x * x + y * y)
-        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
-        return g / g.sum()
-
     def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
         if blur_radius == 0:
             return (image,)
@@ -101,10 +101,11 @@ class Blur:
         batch_size, height, width, channels = image.shape

         kernel_size = blur_radius * 2 + 1
-        kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
+        kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)

         image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
-        blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
+        padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
+        blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
         blurred = blurred.permute(0, 2, 3, 1)

         return (blurred,)
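Note: the padding change is what stops blurred borders from darkening: implicit zero padding averages edge pixels with black, while reflect padding mirrors real pixels and the pad is cropped off after the convolution. A tiny illustration:

import torch
import torch.nn.functional as F

img = torch.arange(16.0).view(1, 1, 4, 4)
padded = F.pad(img, (1, 1, 1, 1), 'reflect')  # mirrors edge rows/cols instead of inserting zeros
assert padded.shape == (1, 1, 6, 6)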
@@ -167,9 +168,15 @@ class Sharpen:
                     "max": 31,
                     "step": 1
                 }),
-                "alpha": ("FLOAT", {
+                "sigma": ("FLOAT", {
                     "default": 1.0,
                     "min": 0.1,
                     "max": 10.0,
                     "step": 0.1
                 }),
+                "alpha": ("FLOAT", {
+                    "default": 1.0,
+                    "min": 0.0,
+                    "max": 5.0,
+                    "step": 0.1
+                }),
@@ -181,21 +188,21 @@ class Sharpen:

     CATEGORY = "image/postprocessing"

-    def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
+    def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
         if sharpen_radius == 0:
             return (image,)

         batch_size, height, width, channels = image.shape

         kernel_size = sharpen_radius * 2 + 1
-        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
+        kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
         center = kernel_size // 2
-        kernel[center, center] = kernel_size**2
-        kernel *= alpha
+        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
         kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

         tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
-        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
+        tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
+        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
         sharpened = sharpened.permute(0, 2, 3, 1)

         result = torch.clamp(sharpened, 0, 1)
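Note: the rewritten kernel keeps overall brightness constant: the negatively scaled Gaussian forms the high-pass surround, and the center weight is then set so the whole kernel sums to exactly 1.0. That invariant can be checked directly (assumes gaussian_kernel from above; alpha illustrative):

import torch

alpha, kernel_size, sigma = 1.0, 5, 1.0
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha * 10)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
assert abs(float(kernel.sum()) - 1.0) < 1e-5  # sum = old_sum - old_center + (old_center - old_sum + 1.0)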
@@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated):

         if len(info) > 1:
             if "min" in info[1] and val < info[1]["min"]:
-                return (False, "Value smaller than min. {}, {}".format(class_type, x))
+                return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
             if "max" in info[1] and val > info[1]["max"]:
-                return (False, "Value bigger than max. {}, {}".format(class_type, x))
+                return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))

     if hasattr(obj_class, "VALIDATE_INPUTS"):
         input_data_all = get_input_data(inputs, obj_class, unique_id)
@@ -148,4 +148,37 @@ def get_filename_list(folder_name):
         output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
     return sorted(list(output_list))

+def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
+    def map_filename(filename):
+        prefix_len = len(os.path.basename(filename_prefix))
+        prefix = filename[:prefix_len + 1]
+        try:
+            digits = int(filename[prefix_len + 1:].split('_')[0])
+        except:
+            digits = 0
+        return (digits, prefix)
+
+    def compute_vars(input, image_width, image_height):
+        input = input.replace("%width%", str(image_width))
+        input = input.replace("%height%", str(image_height))
+        return input
+
+    filename_prefix = compute_vars(filename_prefix, image_width, image_height)
+
+    subfolder = os.path.dirname(os.path.normpath(filename_prefix))
+    filename = os.path.basename(os.path.normpath(filename_prefix))
+
+    full_output_folder = os.path.join(output_dir, subfolder)
+
+    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
+        print("Saving image outside the output folder is not allowed.")
+        return {}
+
+    try:
+        counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
+    except ValueError:
+        counter = 1
+    except FileNotFoundError:
+        os.makedirs(full_output_folder, exist_ok=True)
+        counter = 1
+    return full_output_folder, filename, counter, subfolder, filename_prefix
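Note: the new helper returns everything a save node needs to build collision-free names; SaveImage and the new SaveLatent below both call it. A hedged usage sketch (prefix illustrative):

import os
import folder_paths

folder, filename, counter, subfolder, prefix = folder_paths.get_save_image_path(
    "latents/ComfyUI", folder_paths.get_output_directory())
first_file = os.path.join(folder, f"{filename}_{counter:05}_.latent")  # counter avoids overwrites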
nodes.py (122 changed lines)
@@ -7,16 +7,15 @@ import hashlib
 import traceback
 import math
 import time
-from comfy.aitemplate.model import Model
-from diffusers import LMSDiscreteScheduler
-from PIL import Image
+from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 import numpy as np
 import safetensors.torch


 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))


+from comfy.aitemplate.model import Model
 import comfy.diffusers_convert
 import comfy.samplers
 import comfy.sample
@@ -30,6 +29,7 @@ import importlib

 import folder_paths

+
 def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()

@@ -247,6 +247,81 @@ class VAEEncodeForInpaint:

         return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

+
+class SaveLatent:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ),
+                              "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+                }
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing"
+
+    def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+
+        # support save metadata for latent sharing
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        file = f"{filename}_{counter:05}_.latent"
+        file = os.path.join(full_output_folder, file)
+
+        output = {}
+        output["latent_tensor"] = samples["samples"]
+
+        safetensors.torch.save_file(output, file, metadata=metadata)
+
+        return {}
+
+
+class LoadLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
+        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
+        return {"required": {"latent": [sorted(files), ]}, }
+
+    CATEGORY = "_for_testing"
+
+    RETURN_TYPES = ("LATENT", )
+    FUNCTION = "load"
+
+    def load(self, latent):
+        latent_path = folder_paths.get_annotated_filepath(latent)
+        latent = safetensors.torch.load_file(latent_path, device="cpu")
+        samples = {"samples": latent["latent_tensor"].float()}
+        return (samples, )
+
+    @classmethod
+    def IS_CHANGED(s, latent):
+        image_path = folder_paths.get_annotated_filepath(latent)
+        m = hashlib.sha256()
+        with open(image_path, 'rb') as f:
+            m.update(f.read())
+        return m.digest().hex()
+
+    @classmethod
+    def VALIDATE_INPUTS(s, latent):
+        if not folder_paths.exists_annotated_filepath(latent):
+            return "Invalid latent file: {}".format(latent)
+        return True
+
+
 class CheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
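Note: a .latent file is a plain safetensors container — the tensor lives under the key "latent_tensor" and prompt/workflow JSON rides along as string metadata — so a round trip can be sketched without ComfyUI at all (file name illustrative):

import safetensors.torch
import torch

samples = torch.zeros(1, 4, 64, 64)
safetensors.torch.save_file({"latent_tensor": samples}, "example.latent",
                            metadata={"prompt": "{}"})
loaded = safetensors.torch.load_file("example.latent", device="cpu")
restored = loaded["latent_tensor"].float()  # what LoadLatent returns as "samples"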
@@ -1258,39 +1333,7 @@ class SaveImage:
     CATEGORY = "image"

     def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
-        def map_filename(filename):
-            prefix_len = len(os.path.basename(filename_prefix))
-            prefix = filename[:prefix_len + 1]
-            try:
-                digits = int(filename[prefix_len + 1:].split('_')[0])
-            except:
-                digits = 0
-            return (digits, prefix)
-
-        def compute_vars(input):
-            input = input.replace("%width%", str(images[0].shape[1]))
-            input = input.replace("%height%", str(images[0].shape[0]))
-            return input
-
-        filename_prefix = compute_vars(filename_prefix)
-
-        subfolder = os.path.dirname(os.path.normpath(filename_prefix))
-        filename = os.path.basename(os.path.normpath(filename_prefix))
-
-        full_output_folder = os.path.join(self.output_dir, subfolder)
-
-        if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
-            print("Saving image outside the output folder is not allowed.")
-            return {}
-
-        try:
-            counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
-        except ValueError:
-            counter = 1
-        except FileNotFoundError:
-            os.makedirs(full_output_folder, exist_ok=True)
-            counter = 1
-
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
         results = list()
         for image in images:
             i = 255. * image.cpu().numpy()
@@ -1341,6 +1384,7 @@ class LoadImage:
     def load_image(self, image):
         image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
+        i = ImageOps.exif_transpose(i)
         image = i.convert("RGB")
         image = np.array(image).astype(np.float32) / 255.0
         image = torch.from_numpy(image)[None,]
@@ -1384,6 +1428,7 @@ class LoadImageMask:
     def load_image(self, image, channel):
         image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
+        i = ImageOps.exif_transpose(i)
         if i.getbands() != ("R", "G", "B", "A"):
             i = i.convert("RGBA")
         mask = None
@@ -1566,6 +1611,9 @@ NODE_CLASS_MAPPINGS = {

     "CheckpointLoader": CheckpointLoader,
     "DiffusersLoader": DiffusersLoader,
+
+    "LoadLatent": LoadLatent,
+    "SaveLatent": SaveLatent
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -175,6 +175,8 @@
     "import threading\n",
     "import time\n",
+    "import socket\n",
+    "import urllib.request\n",
     "\n",
     "def iframe_thread(port):\n",
     "  while True:\n",
     "    time.sleep(0.5)\n",
@@ -183,7 +185,9 @@
     "    if result == 0:\n",
     "      break\n",
     "    sock.close()\n",
-    "  print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
+    "  print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
+    "\n",
+    "  print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
     "  p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
     "  for line in p.stdout:\n",
     "    print(line.decode(), end='')\n",
server.py (40 changed lines)
@@ -261,23 +261,34 @@ class PromptServer():
         async def get_prompt(request):
             return web.json_response(self.get_queue_info())

+        def node_info(node_class):
+            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
+            info = {}
+            info['input'] = obj_class.INPUT_TYPES()
+            info['output'] = obj_class.RETURN_TYPES
+            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
+            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
+            info['name'] = node_class
+            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
+            info['description'] = ''
+            info['category'] = 'sd'
+            if hasattr(obj_class, 'CATEGORY'):
+                info['category'] = obj_class.CATEGORY
+            return info
+
         @routes.get("/object_info")
         async def get_object_info(request):
             out = {}
             for x in nodes.NODE_CLASS_MAPPINGS:
-                obj_class = nodes.NODE_CLASS_MAPPINGS[x]
-                info = {}
-                info['input'] = obj_class.INPUT_TYPES()
-                info['output'] = obj_class.RETURN_TYPES
-                info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
-                info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
-                info['name'] = x
-                info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
-                info['description'] = ''
-                info['category'] = 'sd'
-                if hasattr(obj_class, 'CATEGORY'):
-                    info['category'] = obj_class.CATEGORY
-                out[x] = info
+                out[x] = node_info(x)
             return web.json_response(out)

+        @routes.get("/object_info/{node_class}")
+        async def get_object_info_node(request):
+            node_class = request.match_info.get("node_class", None)
+            out = {}
+            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
+                out[node_class] = node_info(node_class)
+            return web.json_response(out)
+
         @routes.get("/history")
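Note: the per-node route spares clients from fetching the whole /object_info map just to inspect one node. A sketch of a client call, assuming the default listen address 127.0.0.1:8188 (node name illustrative):

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/object_info/KSampler") as resp:
    info = json.load(resp)
print(info["KSampler"]["input"])  # the node's INPUT_TYPES, serialized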
@@ -320,7 +331,8 @@ class PromptServer():
                 extra_data["client_id"] = json_data["client_id"]
             if valid[0]:
                 prompt_id = str(uuid.uuid4())
-                self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
+                outputs_to_execute = valid[2]
+                self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
                 return web.json_response({"prompt_id": prompt_id})
             else:
                 print("invalid prompt:", valid[1])
@@ -300,7 +300,7 @@ app.registerExtension({
             }
         }

-        if (widget.type === "number") {
+        if (widget.type === "number" || widget.type === "combo") {
             addValueControlWidget(this, widget, "fixed");
         }

@@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
                 if (show_text) {
                     ctx.textAlign = "center";
                     ctx.fillStyle = text_color;
-                    ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
+                    ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
                 }
                 break;
             case "toggle":
@@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
                 ctx.fill();
                 if (show_text) {
                     ctx.fillStyle = secondary_text_color;
-                    if (w.name != null) {
-                        ctx.fillText(w.name, margin * 2, y + H * 0.7);
+                    const label = w.label || w.name;
+                    if (label != null) {
+                        ctx.fillText(label, margin * 2, y + H * 0.7);
                     }
                     ctx.fillStyle = w.value ? text_color : secondary_text_color;
                     ctx.textAlign = "right";
@@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
                     ctx.textAlign = "center";
                     ctx.fillStyle = text_color;
                     ctx.fillText(
-                        w.name + " " + Number(w.value).toFixed(3),
+                        w.label || w.name + " " + Number(w.value).toFixed(3),
                         widget_width * 0.5,
                         y + H * 0.7
                     );
@@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
                     ctx.fill();
                 }
                 ctx.fillStyle = secondary_text_color;
-                ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
+                ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
                 ctx.fillStyle = text_color;
                 ctx.textAlign = "right";
                 if (w.type == "number") {
@@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)

                 //ctx.stroke();
                 ctx.fillStyle = secondary_text_color;
-                if (w.name != null) {
-                    ctx.fillText(w.name, margin * 2, y + H * 0.7);
+                const label = w.label || w.name;
+                if (label != null) {
+                    ctx.fillText(label, margin * 2, y + H * 0.7);
                 }
                 ctx.fillStyle = text_color;
                 ctx.textAlign = "right";
@@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
 import { ComfyUI, $el } from "./ui.js";
 import { api } from "./api.js";
 import { defaultGraph } from "./defaultGraph.js";
-import { getPngMetadata, importA1111 } from "./pnginfo.js";
+import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";

 /**
  * @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@@ -902,7 +902,9 @@ export class ComfyApp {
         await this.#loadExtensions();

         // Create and mount the LiteGraph in the DOM
-        const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
+        const mainCanvas = document.createElement("canvas")
+        mainCanvas.style.touchAction = "none"
+        const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
         canvasEl.tabIndex = "1";
         document.body.prepend(canvasEl);

|
||||
this.loadGraphData(JSON.parse(reader.result));
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,22 @@ export function getPngMetadata(file) {
     });
 }

+export function getLatentMetadata(file) {
+    return new Promise((r) => {
+        const reader = new FileReader();
+        reader.onload = (event) => {
+            const safetensorsData = new Uint8Array(event.target.result);
+            const dataView = new DataView(safetensorsData.buffer);
+            let header_size = dataView.getUint32(0, true);
+            let offset = 8;
+            let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+            r(header.__metadata__);
+        };
+
+        reader.readAsArrayBuffer(file);
+    });
+}
+
 export async function importA1111(graph, parameters) {
     const p = parameters.lastIndexOf("\nSteps:");
     if (p > -1) {
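Note: getLatentMetadata parses the safetensors framing by hand: the first 8 bytes are a little-endian u64 header length (the JS reads only the low 32 bits, which suffices for realistic headers), followed by a JSON header whose __metadata__ field carries the saved strings. The same read in Python, for comparison (path illustrative):

import json
import struct

def read_latent_metadata(path):
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]  # little-endian u64 length
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})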
@@ -465,7 +465,7 @@ export class ComfyUI {
         const fileInput = $el("input", {
             id: "comfy-file-input",
             type: "file",
-            accept: ".json,image/png",
+            accept: ".json,image/png,.latent",
             style: { display: "none" },
             parent: document.body,
             onchange: () => {
@@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "randomize") {

         var v = valueControl.value;

-        let min = targetWidget.options.min;
-        let max = targetWidget.options.max;
-        // limit to something that javascript can handle
-        max = Math.min(1125899906842624, max);
-        min = Math.max(-1125899906842624, min);
-        let range = (max - min) / (targetWidget.options.step / 10);
-
-        //adjust values based on valueControl Behaviour
-        switch (v) {
-            case "fixed":
-                break;
-            case "increment":
-                targetWidget.value += targetWidget.options.step / 10;
-                break;
-            case "decrement":
-                targetWidget.value -= targetWidget.options.step / 10;
-                break;
-            case "randomize":
-                targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
-            default:
-                break;
-        }
-        /*check if values are over or under their respective
-         * ranges and set them to min or max.*/
-        if (targetWidget.value < min)
-            targetWidget.value = min;
-
-        if (targetWidget.value > max)
-            targetWidget.value = max;
+        if (targetWidget.type == "combo" && v !== "fixed") {
+            let current_index = targetWidget.options.values.indexOf(targetWidget.value);
+            let current_length = targetWidget.options.values.length;
+
+            switch (v) {
+                case "increment":
+                    current_index += 1;
+                    break;
+                case "decrement":
+                    current_index -= 1;
+                    break;
+                case "randomize":
+                    current_index = Math.floor(Math.random() * current_length);
+                default:
+                    break;
+            }
+            current_index = Math.max(0, current_index);
+            current_index = Math.min(current_length - 1, current_index);
+            if (current_index >= 0) {
+                let value = targetWidget.options.values[current_index];
+                targetWidget.value = value;
+                targetWidget.callback(value);
+            }
+        } else { //number
+            let min = targetWidget.options.min;
+            let max = targetWidget.options.max;
+            // limit to something that javascript can handle
+            max = Math.min(1125899906842624, max);
+            min = Math.max(-1125899906842624, min);
+            let range = (max - min) / (targetWidget.options.step / 10);
+
+            //adjust values based on valueControl Behaviour
+            switch (v) {
+                case "fixed":
+                    break;
+                case "increment":
+                    targetWidget.value += targetWidget.options.step / 10;
+                    break;
+                case "decrement":
+                    targetWidget.value -= targetWidget.options.step / 10;
+                    break;
+                case "randomize":
+                    targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
+                default:
+                    break;
+            }
+            /*check if values are over or under their respective
+             * ranges and set them to min or max.*/
+            if (targetWidget.value < min)
+                targetWidget.value = min;
+
+            if (targetWidget.value > max)
+                targetWidget.value = max;
+        }
     }
     return valueControl;
 };
@@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
                 computeSize(node.size);
             }
             const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
-            const t = ctx.getTransform();
             const margin = 10;
+            const elRect = ctx.canvas.getBoundingClientRect();
+            const transform = new DOMMatrix()
+                .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
+                .multiplySelf(ctx.getTransform())
+                .translateSelf(margin, margin + y);
+
             Object.assign(this.inputEl.style, {
-                left: `${t.a * margin + t.e}px`,
-                top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
-                width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
-                background: (!node.color)?'':node.color,
-                height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
+                transformOrigin: "0 0",
+                transform: transform,
+                left: "0px",
+                top: "0px",
+                width: `${widgetWidth - (margin * 2)}px`,
+                height: `${this.parent.inputHeight - (margin * 2)}px`,
+                position: "absolute",
+                background: (!node.color)?'':node.color,
                 color: (!node.color)?'':'white',
                 zIndex: app.graph._nodes.indexOf(node),
-                fontSize: `${t.d * 10.0}px`,
             });
             this.inputEl.hidden = !visible;
         },
@@ -39,6 +39,8 @@ body {
     padding: 2px;
     resize: none;
     border: none;
+    box-sizing: border-box;
+    font-size: 10px;
 }

 .comfy-modal {