diff --git a/README.md b/README.md index aab892531..d83174e3c 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) - Starts up very fast. - Works fully offline: will never download anything. +- [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -136,9 +137,9 @@ This will let you use: pip3.10 to install all the dependencies. ## How to increase generation speed? -The fp16 model configs in the CheckpointLoader can be used to load them in fp16 mode, depending on your GPU this will increase your gen speed by a significant amount. +Make sure you use the CheckpointLoaderSimple node to load checkpoints. It will auto pick the right settings depending on your GPU. -You can also set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. +You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this doesn't do anything when xformers is enabled and will very likely give you black images on SD2.x models. ```--dont-upcast-attention``` diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index e97badd04..23b047342 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -20,11 +20,6 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") -try: - OOM_EXCEPTION = torch.cuda.OutOfMemoryError -except: - OOM_EXCEPTION = Exception - def exists(val): return val is not None @@ -312,7 +307,7 @@ class CrossAttentionDoggettx(nn.Module): r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except OOM_EXCEPTION as e: + except model_management.OOM_EXCEPTION as e: if first_op_done == False: torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 129b86a7f..94f5510b9 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,11 +13,6 @@ if model_management.xformers_enabled(): import xformers import xformers.ops -try: - OOM_EXCEPTION = torch.cuda.OutOfMemoryError -except: - OOM_EXCEPTION = Exception - def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -221,7 +216,7 @@ class AttnBlock(nn.Module): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except OOM_EXCEPTION as e: + except model_management.OOM_EXCEPTION as e: steps *= 2 if steps > 128: raise e @@ -616,19 +611,17 @@ class Encoder(nn.Module): x = torch.nn.functional.pad(x, pad, mode="constant", value=0) already_padded = True # downsampling - hs = [self.conv_in(x)] + h = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) + h = self.down[i_level].block[i_block](h, temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) - hs.append(h) if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1], already_padded)) + h = self.down[i_level].downsample(h, already_padded) already_padded = False # middle - h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index edbff74a2..f3c83f387 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -24,10 +24,7 @@ except ImportError: from torch import Tensor from typing import List -try: - OOM_EXCEPTION = torch.cuda.OutOfMemoryError -except: - OOM_EXCEPTION = Exception +import model_management def dynamic_slice( x: Tensor, @@ -161,7 +158,7 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except OOM_EXCEPTION: + except model_management.OOM_EXCEPTION: print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/model_management.py b/comfy/model_management.py index c26d682f7..809b19ea2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,11 @@ try: except: pass +try: + OOM_EXCEPTION = torch.cuda.OutOfMemoryError +except: + OOM_EXCEPTION = Exception + if "--disable-xformers" in sys.argv: XFORMERS_IS_AVAILBLE = False else: @@ -231,7 +236,7 @@ def should_use_fp16(): return False #FP32 is faster on those cards? - nvidia_16_series = ["1660", "1650", "1630"] + nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"] for x in nvidia_16_series: if x in props.name: return False diff --git a/comfy/sd.py b/comfy/sd.py index 6d1e8bb9b..b344cbece 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -129,12 +129,17 @@ def load_lora(path, to_load): A_name = "{}.lora_up.weight".format(x) B_name = "{}.lora_down.weight".format(x) alpha_name = "{}.alpha".format(x) + mid_name = "{}.lora_mid.weight".format(x) if A_name in lora.keys(): alpha = None if alpha_name in lora.keys(): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha) + mid = None + if mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) loaded_keys.add(A_name) loaded_keys.add(B_name) for x in lora.keys(): @@ -279,6 +284,10 @@ class ModelPatcher: mat2 = v[1] if v[2] is not None: alpha *= v[2] / mat2.shape[0] + if v[3] is not None: + #locon mid weights, hopefully the math is fine because I didn't properly test it + final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): @@ -374,20 +383,34 @@ class VAE: device = model_management.get_torch_device() self.device = device - def decode(self, samples): + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): + decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) + output = torch.clamp(( + (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) + + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) + + utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8)) + / 3.0) / 2.0, min=0.0, max=1.0) + return output + + def decode(self, samples_in): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) - samples = samples.to(self.device) - pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples) - pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0) + try: + samples = samples_in.to(self.device) + pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples) + pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0) + except model_management.OOM_EXCEPTION as e: + print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + pixel_samples = self.decode_tiled_(samples_in) + self.first_stage_model = self.first_stage_model.cpu() pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples - def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 8): + def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) - output = utils.tiled_scale(samples, lambda a: torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) / 2.0, min=0.0, max=1.0), tile_x, tile_y, overlap, upscale_amount = 8) + output = self.decode_tiled_(samples, tile_x, tile_y, overlap) self.first_stage_model = self.first_stage_model.cpu() return output.movedim(1,-1) @@ -405,6 +428,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4) + samples /= 3.0 self.first_stage_model = self.first_stage_model.cpu() samples = samples.cpu() return samples diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index a219c3ecd..23ee669d4 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -10,7 +10,7 @@ class UpscaleModelLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), )}, - "widget": { "Refresh": ("REFRESH", [("model_name", "upscale_models")]) }} + } RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" diff --git a/execution.py b/execution.py index 30eeb6304..757e0d9f9 100644 --- a/execution.py +++ b/execution.py @@ -143,7 +143,7 @@ class PromptExecutor: else: self.server.client_id = None - with torch.no_grad(): + with torch.inference_mode(): for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) diff --git a/nodes.py b/nodes.py index 0f9d88655..59006d946 100644 --- a/nodes.py +++ b/nodes.py @@ -31,6 +31,8 @@ def before_node_execution(): def interrupt_processing(value=True): model_management.interrupt_current_processing(value) +MAX_RESOLUTION=8192 + class CLIPTextEncode: @classmethod def INPUT_TYPES(s): @@ -59,10 +61,10 @@ class ConditioningSetArea: @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}), - "height": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}), - "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}), - "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}), + "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -192,7 +194,7 @@ class CheckpointLoader: def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ) }, - "widget": { "Refresh": ("REFRESH", [("config_name", "configs"), ("ckpt_name", "checkpoints")]) }} + } RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -207,7 +209,7 @@ class CheckpointLoaderSimple: @classmethod def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ) }, - "widget": { "Refresh": ("REFRESH", [("ckpt_name", "checkpoints")]) }} + } RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -242,7 +244,7 @@ class LoraLoader: "lora_name": (folder_paths.get_filename_list("loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) }, - "widget": { "Refresh": ("REFRESH", [("lora_name", "loras")]) }} + } RETURN_TYPES = ("MODEL", "CLIP") FUNCTION = "load_lora" @@ -257,7 +259,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), ) }, - "widget": { "Refresh": ("REFRESH", [("vae_name", "vae")]) }} + } RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -273,7 +275,7 @@ class ControlNetLoader: @classmethod def INPUT_TYPES(s): return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}, - "widget": { "Refresh": ("REFRESH", [("control_net_name", "controlnet")]) }} + } RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -290,7 +292,7 @@ class DiffControlNetLoader: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "control_net_name": (folder_paths.get_filename_list("controlnet"), )}, - "widget": { "Refresh": ("REFRESH", [("control_net_name", "controlnet")]) }} + } RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -333,7 +335,8 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),}, - "widget": { "Refresh": ("REFRESH", [("clip_name", "clip")]) }} + } + RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -341,14 +344,14 @@ class CLIPLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory) + clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) class CLIPVisionLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), )}, - "widget": { "Refresh": ("REFRESH", [("clip_name", "clip_vision")]) } } + } RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -379,7 +382,7 @@ class StyleModelLoader: @classmethod def INPUT_TYPES(s): return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}, - "widget": { "Refresh": ("REFRESH", [("style_model_name", "style_models")]) }} + } RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -418,8 +421,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), + return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -439,8 +442,8 @@ class LatentUpscale: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -501,9 +504,9 @@ class LatentComposite: def INPUT_TYPES(s): return {"required": { "samples_to": ("LATENT",), "samples_from": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}), - "feather": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "composite" @@ -542,10 +545,10 @@ class LatentCrop: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), - "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "crop" @@ -812,7 +815,7 @@ class LoadImage: if not os.path.exists(s.input_dir): os.makedirs(s.input_dir) return {"required": {"image": (sorted(os.listdir(s.input_dir)), )}, - "widget": { "Refresh": ("REFRESH", [("image", "input")]) } } + } CATEGORY = "image" @@ -881,8 +884,8 @@ class ImageScale: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" diff --git a/web/scripts/app.js b/web/scripts/app.js index 1c2e82c8d..33b7abfb0 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -80,10 +80,23 @@ class ComfyApp { img = this.imgs[this.overIndex]; } if (img) { - options.unshift({ - content: "Open Image", - callback: () => window.open(img.src, "_blank"), - }); + options.unshift( + { + content: "Open Image", + callback: () => window.open(img.src, "_blank"), + }, + { + content: "Save Image", + callback: () => { + const a = document.createElement("a"); + a.href = img.src; + a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename")); + document.body.append(a); + a.click(); + requestAnimationFrame(() => a.remove()); + }, + } + ); } } }; @@ -481,6 +494,7 @@ class ComfyApp { // Create and mount the LiteGraph in the DOM const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + canvasEl.tabIndex = "1" document.body.prepend(canvasEl); this.graph = new LGraph(); @@ -760,6 +774,71 @@ class ComfyApp { } this.extensions.push(extension); } + + /** + * Refresh file list on whole nodes + */ + async refreshNodes() { + for(let nodeNum in this.graph._nodes) { + const node = this.graph._nodes[nodeNum]; + + var data = []; + + switch(node.type) { + case "CheckpointLoader": + data = { "config_name": "configs", + "ckpt_name": "checkpoints" }; + break; + + case "CheckpointLoaderSimple": + data = { "ckpt_name": "checkpoints" }; + break; + + case "LoraLoader": + data = { "lora_name": "loras" }; + break; + + case "VAELoader": + data = { "vae_name": "vae" }; + break; + + case "ControlNetLoader": + case "DiffControlNetLoader": + data = { "control_net_name": "controlnet" }; + break; + + case "CLIPLoader": + data = { "clip_name": "clip" }; + break; + + case "CLIPVisionLoader": + data = { "clip_name": "clip_vision" }; + break; + + case "StyleModelLoader": + data = { "style_model_name": "style_models" }; + break; + + case "LoadImage": + data = { "image": "input" }; + break; + + case "UpscaleModelLoader": + data = { "model_name": "upscale_models" }; + break; + + default: + break; + } + + for (let i in data) { + const w = node.widgets.find((w) => w.name === i); + const filelist = await api.getFiles(data[i]); + w.options.values = filelist.files; + w.value = filelist.files[0]; + } + } + } } export const app = new ComfyApp(); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 58012fe6c..a66419b89 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -326,6 +326,7 @@ export class ComfyUI { }, 0); }, }), + $el("button", { textContent: "Refresh", onclick: () => app.refreshNodes() }), $el("button", { textContent: "Load", onclick: () => fileInput.click() }), $el("button", { textContent: "Clear", onclick: () => app.graph.clear() }), $el("button", { textContent: "Load Default", onclick: () => app.loadGraphData() }), diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index b48ff50d7..30a02e72e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,5 +1,3 @@ -import { api } from "./api.js"; - function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; let { min, max, step } = inputData[1]; @@ -29,22 +27,59 @@ function seedWidget(node, inputName, inputData) { return { widget: seed, randomize }; } -function refreshWidget(node, name, data) { - async function refresh_callback() { - const items = data[1]; - for (let i in items) { - const w = node.widgets.find((w) => w.name === items[i][0]); - const filelist = await api.getFiles(items[i][1]); - w.options.values = filelist.files; - w.value = filelist.files[0]; +const MultilineSymbol = Symbol(); + +function addMultilineWidget(node, name, opts, app) { + const MIN_SIZE = 50; + + function computeSize(size) { + if (node.widgets[0].last_y == null) return; + + let y = node.widgets[0].last_y; + let freeSpace = size[1] - y; + + // Compute the height of all non customtext widgets + let widgetHeight = 0; + const multi = []; + for (let i = 0; i < node.widgets.length; i++) { + const w = node.widgets[i]; + if (w.type === "customtext") { + multi.push(w); + } else { + if (w.computeSize) { + widgetHeight += w.computeSize()[1] + 4; + } else { + widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } } + + // See how large each text input can be + freeSpace -= widgetHeight; + freeSpace /= multi.length; + + if (freeSpace < MIN_SIZE) { + // There isnt enough space for all the widgets, increase the size of the node + freeSpace = MIN_SIZE; + node.size[1] = y + widgetHeight + freeSpace * multi.length; + node.graph.setDirtyCanvas(true); + } + + // Position each of the widgets + for (const w of node.widgets) { + w.y = y; + if (w.type === "customtext") { + y += freeSpace; + } else if (w.computeSize) { + y += w.computeSize()[1] + 4; + } else { + y += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } + + node.inputHeight = freeSpace; } - const refresh = node.addWidget("button", name, true, function(v) { refresh_callback(); }, {}); - return { refresh }; -} - -function addMultilineWidget(node, name, defaultVal, app) { const widget = { type: "customtext", name, @@ -55,14 +90,19 @@ function addMultilineWidget(node, name, defaultVal, app) { this.inputEl.value = x; }, draw: function (ctx, _, widgetWidth, y, widgetHeight) { + if (!this.parent.inputHeight) { + // If we are initially offscreen when created we wont have received a resize event + // Calculate it here instead + computeSize(node.size); + } const visible = app.canvas.ds.scale > 0.5; const t = ctx.getTransform(); const margin = 10; Object.assign(this.inputEl.style, { left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin) + t.f}px`, + top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - height: `${(this.parent.size[1] - (y + widgetHeight) - 3) * t.d}px`, + height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, position: "absolute", zIndex: 1, fontSize: `${t.d * 10.0}px`, @@ -72,7 +112,8 @@ function addMultilineWidget(node, name, defaultVal, app) { }; widget.inputEl = document.createElement("textarea"); widget.inputEl.className = "comfy-multiline-input"; - widget.inputEl.value = defaultVal; + widget.inputEl.value = opts.defaultVal; + widget.inputEl.placeholder = opts.placeholder || ""; document.addEventListener("mousedown", function (event) { if (!widget.inputEl.contains(event.target)) { widget.inputEl.blur(); @@ -108,6 +149,20 @@ function addMultilineWidget(node, name, defaultVal, app) { } }; + if (!(MultilineSymbol in node)) { + node[MultilineSymbol] = true; + const onResize = node.onResize; + + node.onResize = function (size) { + computeSize(size); + + // Call original resizer handler + if (onResize) { + onResize.apply(this, arguments); + } + }; + } + return { minWidth: 400, minHeight: 200, widget }; } @@ -120,6 +175,7 @@ export const ComfyWidgets = { }, INT(node, inputName, inputData) { const { val, config } = getNumberDefaults(inputData, 1); + Object.assign(config, { precision: 0 }); return { widget: node.addWidget( "number", @@ -133,13 +189,12 @@ export const ComfyWidgets = { ), }; }, - REFRESH:refreshWidget, STRING(node, inputName, inputData, app) { const defaultVal = inputData[1].default || ""; const multiline = !!inputData[1].multiline; if (multiline) { - return addMultilineWidget(node, inputName, defaultVal, app); + return addMultilineWidget(node, inputName, { defaultVal, ...inputData[1] }, app); } else { return { widget: node.addWidget("text", inputName, defaultVal, () => {}, {}) }; } @@ -151,7 +206,7 @@ export const ComfyWidgets = { function showImage(name) { // Position the image somewhere sensible if (!node.imageOffset) { - node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 50 : 100; + node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75; } const img = new Image();