Added file reload feature to widgets.

This commit is contained in:
Lt.Dr.Data 2023-03-20 15:41:02 +09:00
parent ae87d8816f
commit 468dd38628
5 changed files with 90 additions and 7 deletions

View File

@ -10,6 +10,7 @@ class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
"RELOAD": ("RELOAD", [("model_name", "upscale_models")]),
}}
RETURN_TYPES = ("UPSCALE_MODEL",)
FUNCTION = "load_model"

View File

@ -191,7 +191,8 @@ class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), )}}
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"RELOAD": ("RELOAD", [("config_name", "configs"), ("ckpt_name", "checkpoints")]) }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -206,7 +207,7 @@ class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
"RELOAD": ("RELOAD", [("ckpt_name", "checkpoints")]) }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -241,6 +242,7 @@ class LoraLoader:
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"RELOAD": ("RELOAD", [("lora_name", "loras")])
}}
RETURN_TYPES = ("MODEL", "CLIP")
FUNCTION = "load_lora"
@ -255,7 +257,9 @@ class LoraLoader:
class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), ),
"RELOAD": ("RELOAD", [("vae_name", "vae")])
}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -270,7 +274,9 @@ class VAELoader:
class ControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
"RELOAD": ("RELOAD", [("control_net_name", "controlnet")])
}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -286,7 +292,9 @@ class DiffControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
"RELOAD": ("RELOAD", [("control_net_name", "controlnet")])
}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -329,6 +337,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"RELOAD": ("RELOAD", [("clip_name", "clip")])
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -344,6 +353,7 @@ class CLIPVisionLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
"RELOAD": ("RELOAD", [("clip_name", "clip_vision")])
}}
RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip"
@ -373,7 +383,9 @@ class CLIPVisionEncode:
class StyleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), ),
"RELOAD": ("RELOAD", [("style_model_name", "style_models")])
}}
RETURN_TYPES = ("STYLE_MODEL",)
FUNCTION = "load_style_model"
@ -790,7 +802,10 @@ class LoadImage:
if not os.path.exists(s.input_dir):
os.makedirs(s.input_dir)
return {"required":
{"image": (sorted(os.listdir(s.input_dir)), )},
{"image": (sorted(os.listdir(s.input_dir)), ),
"RELOAD": ("RELOAD", [("image", "input")])
},
}
CATEGORY = "image"

View File

@ -151,6 +151,46 @@ class PromptServer():
out[x] = info
return web.json_response(out)
@routes.get("/getfiles/{kind}")
async def get_filelist(request):
out = {}
out["files"] = {}
if "kind" in request.match_info:
kind = request.match_info["kind"]
path = ""
# whitelist policy for security reason
if kind == "checkpoints":
path = kind
elif kind == "loras":
path = kind
elif kind == "vae":
path = kind
elif kind == "controlnet":
path = kind
elif kind == "clip":
path = kind
elif kind == "clip_vision":
path = kind
elif kind == "style_models":
path = kind
elif kind == "upscale_models":
path = kind
elif kind == "input":
path = kind # must not be empty
else:
path = ""
if path != "":
if path == "input":
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
out["files"] = sorted(os.listdir(input_dir))
else:
out["files"] = folder_paths.get_filename_list(path)
return web.json_response(out)
@routes.get("/history")
async def get_history(request):
return web.json_response(self.prompt_queue.get_history())

View File

@ -124,6 +124,15 @@ class ComfyApi extends EventTarget {
return await resp.json();
}
/**
* Loads file list
* @returns An array of file list
*/
async getFiles(kind) {
const resp = await fetch("/getfiles/"+kind, { cache: "no-store" });
return await resp.json();
}
/**
*
* @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue

View File

@ -1,3 +1,5 @@
import { api } from "./api.js";
function getNumberDefaults(inputData, defaultStep) {
let defaultVal = inputData[1]["default"];
let { min, max, step } = inputData[1];
@ -27,6 +29,21 @@ function seedWidget(node, inputName, inputData) {
return { widget: seed, randomize };
}
function reloadWidget(node, name, data) {
async function reload_callback() {
const items = data[1];
for (let i in items) {
const w = node.widgets.find((w) => w.name === items[i][0]);
const filelist = await api.getFiles(items[i][1]);
w.options.values = filelist.files;
w.value = filelist.files[0];
}
}
const reload = node.addWidget("button", "RELOAD", true, function(v) { reload_callback(); }, {});
return { reload };
}
function addMultilineWidget(node, name, defaultVal, app) {
const widget = {
type: "customtext",
@ -116,6 +133,7 @@ export const ComfyWidgets = {
),
};
},
RELOAD:reloadWidget,
STRING(node, inputName, inputData, app) {
const defaultVal = inputData[1].default || "";
const multiline = !!inputData[1].multiline;