From 5db59edbf66eea579b9d7253c0769192ddb16ee3 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 12 Feb 2023 01:06:38 -0800 Subject: [PATCH] Scan for models up front --- nodes.py | 10 +++++----- shared.py | 9 +++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nodes.py b/nodes.py index e54e95cad..53e7f18cb 100644 --- a/nodes.py +++ b/nodes.py @@ -107,8 +107,8 @@ class CheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "config_name": ("COMBO", { "choices": shared.get_model_files("configs") }), - "ckpt_name": ("COMBO", { "choices": shared.get_model_files("checkpoints") })}} + return {"required": { "config_name": ("COMBO", { "choices": shared.all_models["configs"] }), + "ckpt_name": ("COMBO", { "choices": shared.all_models["checkpoints"] })}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -124,7 +124,7 @@ class LoraLoader: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), - "lora_name": ("COMBO", { "choices": shared.get_model_files("loras") }), + "lora_name": ("COMBO", { "choices": shared.all_models["loras"] }), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} @@ -141,7 +141,7 @@ class LoraLoader: class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": ("COMBO", { "choices": shared.get_model_files("vae") })}} + return {"required": { "vae_name": ("COMBO", { "choices": shared.all_models["vae"] })}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -156,7 +156,7 @@ class VAELoader: class CLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": ("COMBO", { "choices": shared.get_model_files("clip") }), + return {"required": { "clip_name": ("COMBO", { "choices": shared.all_models["clip"] }), "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), }} RETURN_TYPES = ("CLIP",) diff --git a/shared.py b/shared.py index e3f330460..3eb47bc7f 100644 --- a/shared.py +++ b/shared.py @@ -21,6 +21,9 @@ model_kinds = { } +all_models = {} + + def recursive_search(directory): result = [] for root, subdir, file in os.walk(directory, followlinks=True): @@ -60,3 +63,9 @@ def find_model_file(kind, basename): config = {} with open("config.yml", "r") as f: config = yaml.safe_load(f)["config"] + + +print("Scanning for models...") +for kind in model_kinds.keys(): + all_models[kind] = get_model_files(kind) +print("Done.")