From 06b77605e51c5f187c9b8b9519ee0890c160f489 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sun, 26 Mar 2023 21:17:44 +0900 Subject: [PATCH] patch refresh feature for general method --- comfy_extras/nodes_upscale_model.py | 2 + nodes.py | 21 ++++++++++ server.py | 42 +------------------ web/scripts/app.js | 63 +++++------------------------ 4 files changed, 34 insertions(+), 94 deletions(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 23ee669d4..f1aaf5009 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -16,6 +16,8 @@ class UpscaleModelLoader: CATEGORY = "loaders" + REFRESH_LIST = [("model_name", "upscale_models")] + def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) sd = load_torch_file(model_path) diff --git a/nodes.py b/nodes.py index ddaf6ac37..d878dfaee 100644 --- a/nodes.py +++ b/nodes.py @@ -199,6 +199,8 @@ class CheckpointLoader: CATEGORY = "loaders" + REFRESH_LIST = [("config_name", "configs"), ("ckpt_name", "checkpoints")] + def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): config_path = folder_paths.get_full_path("configs", config_name) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) @@ -214,6 +216,8 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" + REFRESH_LIST = [("ckpt_name", "checkpoints")] + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) @@ -249,6 +253,8 @@ class LoraLoader: CATEGORY = "loaders" + REFRESH_LIST = [("lora_name", "loras")] + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) @@ -263,6 +269,8 @@ class VAELoader: CATEGORY = "loaders" + REFRESH_LIST = [("vae_name", "vae")] + #TODO: scale factor? def load_vae(self, vae_name): vae_path = folder_paths.get_full_path("vae", vae_name) @@ -279,6 +287,8 @@ class ControlNetLoader: CATEGORY = "loaders" + REFRESH_LIST = [("control_net_name", "controlnet")] + def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path) @@ -295,6 +305,8 @@ class DiffControlNetLoader: CATEGORY = "loaders" + REFRESH_LIST = [("control_net_name", "controlnet")] + def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path("controlnet", control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path, model) @@ -337,6 +349,8 @@ class CLIPLoader: CATEGORY = "loaders" + REFRESH_LIST = [("clip_name", "clip")] + def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip", clip_name) clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings")) @@ -352,6 +366,8 @@ class CLIPVisionLoader: CATEGORY = "loaders" + REFRESH_LIST = [("clip_name", "clip_vision")] + def load_clip(self, clip_name): clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_vision = comfy_extras.clip_vision.load(clip_path) @@ -382,6 +398,8 @@ class StyleModelLoader: CATEGORY = "loaders" + REFRESH_LIST = [("style_model_name", "style_models")] + def load_style_model(self, style_model_name): style_model_path = folder_paths.get_full_path("style_models", style_model_name) style_model = comfy.sd.load_style_model(style_model_path) @@ -815,6 +833,9 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + + REFRESH_LIST = [("image", "input")] + def load_image(self, image): image_path = os.path.join(self.input_dir, image) i = Image.open(image_path) diff --git a/server.py b/server.py index b007f9e43..d5f0d1736 100644 --- a/server.py +++ b/server.py @@ -155,49 +155,11 @@ class PromptServer(): info['category'] = 'sd' if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY + if hasattr(obj_class, 'REFRESH_LIST'): + info['refresh_list'] = obj_class.REFRESH_LIST out[x] = info return web.json_response(out) - @routes.get("/getfiles/{kind}") - async def get_filelist(request): - out = {} - - out["files"] = {} - if "kind" in request.match_info: - kind = request.match_info["kind"] - path = "" - - # whitelist policy for security reason - if kind == "checkpoints": - path = kind - elif kind == "loras": - path = kind - elif kind == "vae": - path = kind - elif kind == "controlnet": - path = kind - elif kind == "clip": - path = kind - elif kind == "clip_vision": - path = kind - elif kind == "style_models": - path = kind - elif kind == "upscale_models": - path = kind - elif kind == "input": - path = kind # must not be empty - else: - path = "" - - if path != "": - if path == "input": - input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") - out["files"] = sorted(os.listdir(input_dir)) - else: - out["files"] = folder_paths.get_filename_list(path) - - return web.json_response(out) - @routes.get("/history") async def get_history(request): return web.json_response(self.prompt_queue.get_history()) diff --git a/web/scripts/app.js b/web/scripts/app.js index b4c126252..b4b71165a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -808,63 +808,18 @@ class ComfyApp { * Refresh file list on whole nodes */ async refreshNodes() { + const defs = await api.getNodeDefs(); + for(let nodeNum in this.graph._nodes) { const node = this.graph._nodes[nodeNum]; - - var data = []; - switch(node.type) { - case "CheckpointLoader": - data = { "config_name": "configs", - "ckpt_name": "checkpoints" }; - break; - - case "CheckpointLoaderSimple": - data = { "ckpt_name": "checkpoints" }; - break; - - case "LoraLoader": - data = { "lora_name": "loras" }; - break; - - case "VAELoader": - data = { "vae_name": "vae" }; - break; - - case "ControlNetLoader": - case "DiffControlNetLoader": - data = { "control_net_name": "controlnet" }; - break; - - case "CLIPLoader": - data = { "clip_name": "clip" }; - break; - - case "CLIPVisionLoader": - data = { "clip_name": "clip_vision" }; - break; - - case "StyleModelLoader": - data = { "style_model_name": "style_models" }; - break; - - case "LoadImage": - data = { "image": "input" }; - break; - - case "UpscaleModelLoader": - data = { "model_name": "upscale_models" }; - break; - - default: - break; - } - - for (let i in data) { - const w = node.widgets.find((w) => w.name === i); - const filelist = await api.getFiles(data[i]); - w.options.values = filelist.files; - w.value = filelist.files[0]; + const def = defs[node.type]; + for(const i in def.refresh_list) { + const item = def.refresh_list[i]; + const filelist = def.input["required"][item[0]]; + const w = node.widgets.find((w) => w.name === item[0]); + w.options.values = filelist[0]; + w.value = w.options.values[0]; } } }