Scan for models up front

This commit is contained in:
space-nuko 2023-02-12 01:06:38 -08:00
parent 75ea299e74
commit 5db59edbf6
2 changed files with 14 additions and 5 deletions

View File

@ -107,8 +107,8 @@ class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": ("COMBO", { "choices": shared.get_model_files("configs") }),
"ckpt_name": ("COMBO", { "choices": shared.get_model_files("checkpoints") })}}
return {"required": { "config_name": ("COMBO", { "choices": shared.all_models["configs"] }),
"ckpt_name": ("COMBO", { "choices": shared.all_models["checkpoints"] })}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -124,7 +124,7 @@ class LoraLoader:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": ("COMBO", { "choices": shared.get_model_files("loras") }),
"lora_name": ("COMBO", { "choices": shared.all_models["loras"] }),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
@ -141,7 +141,7 @@ class LoraLoader:
class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": ("COMBO", { "choices": shared.get_model_files("vae") })}}
return {"required": { "vae_name": ("COMBO", { "choices": shared.all_models["vae"] })}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -156,7 +156,7 @@ class VAELoader:
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": ("COMBO", { "choices": shared.get_model_files("clip") }),
return {"required": { "clip_name": ("COMBO", { "choices": shared.all_models["clip"] }),
"stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
}}
RETURN_TYPES = ("CLIP",)

View File

@ -21,6 +21,9 @@ model_kinds = {
}
all_models = {}
def recursive_search(directory):
result = []
for root, subdir, file in os.walk(directory, followlinks=True):
@ -60,3 +63,9 @@ def find_model_file(kind, basename):
config = {}
with open("config.yml", "r") as f:
config = yaml.safe_load(f)["config"]
print("Scanning for models...")
for kind in model_kinds.keys():
all_models[kind] = get_model_files(kind)
print("Done.")