Merge branch 'master' into aitemplate

This commit is contained in:
hlky 2023-05-22 08:42:54 +01:00 committed by GitHub
commit 7e4da3c48a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 367 additions and 164 deletions

View File

@ -605,3 +605,47 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
return x

View File

@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
return x+h
def slice_attention(q, k, v):
r1 = torch.zeros_like(k, device=q.device)
scale = (int(q.shape[-1])**(-0.5))
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
return r1
class AttnBlock(nn.Module):
def __init__(self, in_channels):
@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
# compute attention
b,c,h,w = q.shape
scale = (int(c)**(-0.5))
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = torch.zeros_like(k, device=q.device)
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_)
return x+h_
@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out

View File

@ -555,10 +555,10 @@ class AITemplateModelWrapper:
return noise_pred
class KSampler:
SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"]
SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
"dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"]
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}, aitemplate=None, cfg=None):
self.model = model
@ -596,6 +596,8 @@ class KSampler:
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple":

View File

@ -59,6 +59,12 @@ class Blend:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def gaussian_kernel(kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
class Blur:
def __init__(self):
pass
@ -88,12 +94,6 @@ class Blur:
CATEGORY = "image/postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0:
return (image,)
@ -101,10 +101,11 @@ class Blur:
batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
blurred = blurred.permute(0, 2, 3, 1)
return (blurred,)
@ -167,9 +168,15 @@ class Sharpen:
"max": 31,
"step": 1
}),
"alpha": ("FLOAT", {
"sigma": ("FLOAT", {
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.1
}),
@ -181,21 +188,21 @@ class Sharpen:
CATEGORY = "image/postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
if sharpen_radius == 0:
return (image,)
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel_size**2
kernel *= alpha
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1)

View File

@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated):
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
return (False, "Value smaller than min. {}, {}".format(class_type, x))
return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x))
if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x))
return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x))
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)

View File

@ -148,4 +148,37 @@ def get_filename_list(folder_name):
output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
return sorted(list(output_list))
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input, image_width, image_height):
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
return input
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return full_output_folder, filename, counter, subfolder, filename_prefix

122
nodes.py
View File

@ -7,16 +7,15 @@ import hashlib
import traceback
import math
import time
from comfy.aitemplate.model import Model
from diffusers import LMSDiscreteScheduler
from PIL import Image
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
from comfy.aitemplate.model import Model
import comfy.diffusers_convert
import comfy.samplers
import comfy.sample
@ -30,6 +29,7 @@ import importlib
import folder_paths
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@ -247,6 +247,81 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class SaveLatent:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# support save metadata for latent sharing
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
file = f"{filename}_{counter:05}_.latent"
file = os.path.join(full_output_folder, file)
output = {}
output["latent_tensor"] = samples["samples"]
safetensors.torch.save_file(output, file, metadata=metadata)
return {}
class LoadLatent:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"].float()}
return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, latent):
if not folder_paths.exists_annotated_filepath(latent):
return "Invalid latent file: {}".format(latent)
return True
class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
@ -1258,39 +1333,7 @@ class SaveImage:
CATEGORY = "image"
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
def map_filename(filename):
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
def compute_vars(input):
input = input.replace("%width%", str(images[0].shape[1]))
input = input.replace("%height%", str(images[0].shape[0]))
return input
filename_prefix = compute_vars(filename_prefix)
subfolder = os.path.dirname(os.path.normpath(filename_prefix))
filename = os.path.basename(os.path.normpath(filename_prefix))
full_output_folder = os.path.join(self.output_dir, subfolder)
if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
counter = 1
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
for image in images:
i = 255. * image.cpu().numpy()
@ -1341,6 +1384,7 @@ class LoadImage:
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
@ -1384,6 +1428,7 @@ class LoadImageMask:
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA")
mask = None
@ -1566,6 +1611,9 @@ NODE_CLASS_MAPPINGS = {
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -175,6 +175,8 @@
"import threading\n",
"import time\n",
"import socket\n",
"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",
@ -183,7 +185,9 @@
" if result == 0:\n",
" break\n",
" sock.close()\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n",
" print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n",
"\n",
" print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n",
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
" for line in p.stdout:\n",
" print(line.decode(), end='')\n",

View File

@ -261,23 +261,34 @@ class PromptServer():
async def get_prompt(request):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
return info
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
obj_class = nodes.NODE_CLASS_MAPPINGS[x]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
info['description'] = ''
info['category'] = 'sd'
if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY
out[x] = info
out[x] = node_info(x)
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
return web.json_response(out)
@routes.get("/history")
@ -320,7 +331,8 @@ class PromptServer():
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
prompt_id = str(uuid.uuid4())
self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2]))
outputs_to_execute = valid[2]
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
return web.json_response({"prompt_id": prompt_id})
else:
print("invalid prompt:", valid[1])

View File

@ -300,7 +300,7 @@ app.registerExtension({
}
}
if (widget.type === "number") {
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
}

View File

@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action)
if (show_text) {
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7);
ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7);
}
break;
case "toggle":
@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
if (show_text) {
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = w.value ? text_color : secondary_text_color;
ctx.textAlign = "right";
@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.textAlign = "center";
ctx.fillStyle = text_color;
ctx.fillText(
w.name + " " + Number(w.value).toFixed(3),
w.label || w.name + " " + Number(w.value).toFixed(3),
widget_width * 0.5,
y + H * 0.7
);
@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action)
ctx.fill();
}
ctx.fillStyle = secondary_text_color;
ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7);
ctx.fillStyle = text_color;
ctx.textAlign = "right";
if (w.type == "number") {
@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action)
//ctx.stroke();
ctx.fillStyle = secondary_text_color;
if (w.name != null) {
ctx.fillText(w.name, margin * 2, y + H * 0.7);
const label = w.label || w.name;
if (label != null) {
ctx.fillText(label, margin * 2, y + H * 0.7);
}
ctx.fillStyle = text_color;
ctx.textAlign = "right";

View File

@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111 } from "./pnginfo.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@ -902,7 +902,9 @@ export class ComfyApp {
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
const mainCanvas = document.createElement("canvas")
mainCanvas.style.touchAction = "none"
const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" }));
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
@ -1306,6 +1308,11 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
}
}
}

View File

@ -47,6 +47,22 @@ export function getPngMetadata(file) {
});
}
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const safetensorsData = new Uint8Array(event.target.result);
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};
reader.readAsArrayBuffer(file);
});
}
export async function importA1111(graph, parameters) {
const p = parameters.lastIndexOf("\nSteps:");
if (p > -1) {

View File

@ -465,7 +465,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png",
accept: ".json,image/png,.latent",
style: { display: "none" },
parent: document.body,
onchange: () => {

View File

@ -19,35 +19,60 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
var v = valueControl.value;
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
switch (v) {
case "increment":
current_index += 1;
break;
case "decrement":
current_index -= 1;
break;
case "randomize":
current_index = Math.floor(Math.random() * current_length);
default:
break;
}
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
max = Math.min(1125899906842624, max);
min = Math.max(-1125899906842624, min);
let range = (max - min) / (targetWidget.options.step / 10);
//adjust values based on valueControl Behaviour
switch (v) {
case "fixed":
break;
case "increment":
targetWidget.value += targetWidget.options.step / 10;
break;
case "decrement":
targetWidget.value -= targetWidget.options.step / 10;
break;
case "randomize":
targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min;
default:
break;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
}
return valueControl;
};
@ -130,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) {
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const t = ctx.getTransform();
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
Object.assign(this.inputEl.style, {
left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
background: (!node.color)?'':node.color,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
transformOrigin: "0 0",
transform: transform,
left: "0px",
top: "0px",
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
fontSize: `${t.d * 10.0}px`,
});
this.inputEl.hidden = !visible;
},

View File

@ -39,6 +39,8 @@ body {
padding: 2px;
resize: none;
border: none;
box-sizing: border-box;
font-size: 10px;
}
.comfy-modal {