From 468dd3862899e7d2fd998317423c41e403120009 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Mon, 20 Mar 2023 15:41:02 +0900 Subject: [PATCH] Added file reload feature to widgets. --- comfy_extras/nodes_upscale_model.py | 1 + nodes.py | 29 ++++++++++++++++----- server.py | 40 +++++++++++++++++++++++++++++ web/scripts/api.js | 9 +++++++ web/scripts/widgets.js | 18 +++++++++++++ 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index b79b78511..f2d3cc8b6 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -10,6 +10,7 @@ class UpscaleModelLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), + "RELOAD": ("RELOAD", [("model_name", "upscale_models")]), }} RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" diff --git a/nodes.py b/nodes.py index 7589a0abb..4167c39a5 100644 --- a/nodes.py +++ b/nodes.py @@ -191,7 +191,8 @@ class CheckpointLoader: @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}} + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "RELOAD": ("RELOAD", [("config_name", "configs"), ("ckpt_name", "checkpoints")]) }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -206,7 +207,7 @@ class CheckpointLoaderSimple: @classmethod def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + "RELOAD": ("RELOAD", [("ckpt_name", "checkpoints")]) }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -241,6 +242,7 @@ class LoraLoader: "lora_name": (folder_paths.get_filename_list("loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "RELOAD": ("RELOAD", [("lora_name", "loras")]) }} RETURN_TYPES = ("MODEL", "CLIP") FUNCTION = "load_lora" @@ -255,7 +257,9 @@ class LoraLoader: class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} + return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), ), + "RELOAD": ("RELOAD", [("vae_name", "vae")]) + }} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -270,7 +274,9 @@ class VAELoader: class ControlNetLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} + return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), ), + "RELOAD": ("RELOAD", [("control_net_name", "controlnet")]) + }} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -286,7 +292,9 @@ class DiffControlNetLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} + "control_net_name": (folder_paths.get_filename_list("controlnet"), ), + "RELOAD": ("RELOAD", [("control_net_name", "controlnet")]) + }} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -329,6 +337,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), + "RELOAD": ("RELOAD", [("clip_name", "clip")]) }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -344,6 +353,7 @@ class CLIPVisionLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ), + "RELOAD": ("RELOAD", [("clip_name", "clip_vision")]) }} RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -373,7 +383,9 @@ class CLIPVisionEncode: class StyleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}} + return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), ), + "RELOAD": ("RELOAD", [("style_model_name", "style_models")]) + }} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -790,7 +802,10 @@ class LoadImage: if not os.path.exists(s.input_dir): os.makedirs(s.input_dir) return {"required": - {"image": (sorted(os.listdir(s.input_dir)), )}, + {"image": (sorted(os.listdir(s.input_dir)), ), + "RELOAD": ("RELOAD", [("image", "input")]) + }, + } CATEGORY = "image" diff --git a/server.py b/server.py index e2d129e3f..93f8d0f5f 100644 --- a/server.py +++ b/server.py @@ -151,6 +151,46 @@ class PromptServer(): out[x] = info return web.json_response(out) + @routes.get("/getfiles/{kind}") + async def get_filelist(request): + out = {} + + out["files"] = {} + if "kind" in request.match_info: + kind = request.match_info["kind"] + path = "" + + # whitelist policy for security reason + if kind == "checkpoints": + path = kind + elif kind == "loras": + path = kind + elif kind == "vae": + path = kind + elif kind == "controlnet": + path = kind + elif kind == "clip": + path = kind + elif kind == "clip_vision": + path = kind + elif kind == "style_models": + path = kind + elif kind == "upscale_models": + path = kind + elif kind == "input": + path = kind # must not be empty + else: + path = "" + + if path != "": + if path == "input": + input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + out["files"] = sorted(os.listdir(input_dir)) + else: + out["files"] = folder_paths.get_filename_list(path) + + return web.json_response(out) + @routes.get("/history") async def get_history(request): return web.json_response(self.prompt_queue.get_history()) diff --git a/web/scripts/api.js b/web/scripts/api.js index b90b1c656..cd28d1d29 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -124,6 +124,15 @@ class ComfyApi extends EventTarget { return await resp.json(); } + /** + * Loads file list + * @returns An array of file list + */ + async getFiles(kind) { + const resp = await fetch("/getfiles/"+kind, { cache: "no-store" }); + return await resp.json(); + } + /** * * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 13d271137..6383d250b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,3 +1,5 @@ +import { api } from "./api.js"; + function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; let { min, max, step } = inputData[1]; @@ -27,6 +29,21 @@ function seedWidget(node, inputName, inputData) { return { widget: seed, randomize }; } +function reloadWidget(node, name, data) { + async function reload_callback() { + const items = data[1]; + for (let i in items) { + const w = node.widgets.find((w) => w.name === items[i][0]); + const filelist = await api.getFiles(items[i][1]); + w.options.values = filelist.files; + w.value = filelist.files[0]; + } + } + + const reload = node.addWidget("button", "RELOAD", true, function(v) { reload_callback(); }, {}); + return { reload }; +} + function addMultilineWidget(node, name, defaultVal, app) { const widget = { type: "customtext", @@ -116,6 +133,7 @@ export const ComfyWidgets = { ), }; }, + RELOAD:reloadWidget, STRING(node, inputName, inputData, app) { const defaultVal = inputData[1].default || ""; const multiline = !!inputData[1].multiline;