models paths

This commit is contained in:
ljleb 2023-03-13 22:01:51 -04:00
parent 0e836d525e
commit d89952d090
3 changed files with 39 additions and 17 deletions

View File

@ -2,7 +2,7 @@ import os
from comfy_extras.chainner_models import model_loading
from comfy.sd import load_torch_file
import comfy.model_management
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions, extract_arg_values
import torch
import comfy.utils
@ -12,7 +12,7 @@ class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ),
return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir, *extract_arg_values('--upscaler-dir')), supported_ckpt_extensions), ),
}}
RETURN_TYPES = ("UPSCALE_MODEL",)
FUNCTION = "load_model"

13
main.py
View File

@ -13,6 +13,19 @@ if __name__ == "__main__":
print("Valid Command line Arguments:")
print("\t--listen\t\t\tListen on 0.0.0.0 so the UI can be accessed from other computers.")
print("\t--port 8188\t\t\tSet the listen port.")
s = os.path.sep
print(f"\t--ckpt-dir path{s}to{s}dir\t\t\tAdd a path to a checkpoints dir.")
print(f"\t--clip-dir path{s}to{s}dir\t\t\tAdd a path to a clip dir.")
print(f"\t--clip-vision-dir path{s}to{s}dir\t\t\tAdd a path to a clip vision dir.")
print(f"\t--controlnet-dir path{s}to{s}dir\t\t\tAdd a path to a controlnet checkpoints dir.")
print(f"\t--embed-dir path{s}to{s}dir\t\t\tAdd a path to an embeddings dir.")
print(f"\t--lora-dir path{s}to{s}dir\t\t\tAdd a path to a lora dir.")
print(f"\t--style-model-dir path{s}to{s}dir\t\t\tAdd a path to a style models dir.")
print(f"\t--t2i-dir path{s}to{s}dir\t\t\tAdd a path to a T2I style adapter dir.")
print(f"\t--upscaler-dir path{s}to{s}dir\t\t\tAdd a path to an upscale models dir.")
print(f"\t--vae-dir path{s}to{s}dir\t\t\tAdd a path to a vae dir.")
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")

View File

@ -32,12 +32,21 @@ try:
except:
print("Could not import safetensors, safetensors support disabled.")
def recursive_search(directory):
def extract_arg_values(option):
result = []
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path,join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
for i in range(len(sys.argv) - 1):
if sys.argv[i] == option:
result.append(sys.argv[i + 1])
i += 1
return result
def recursive_search(*directories):
result = []
for directory in directories:
for root, subdir, file in os.walk(directory, followlinks=True):
for filepath in file:
#we os.path,join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
return result
def filter_files_extensions(files, extensions):
@ -214,7 +223,7 @@ class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ),
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir, *extract_arg_values('--ckpt-dir')), supported_ckpt_extensions), )}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -231,7 +240,7 @@ class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir, *extract_arg_values('--ckpt-dir')), supported_ckpt_extensions), ),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -266,7 +275,7 @@ class LoraLoader:
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir, *extract_arg_values('--lora-dir')), supported_pt_extensions), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
@ -285,7 +294,7 @@ class VAELoader:
vae_dir = os.path.join(models_dir, "vae")
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir, *extract_arg_values('--vae-dir')), supported_pt_extensions), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -302,7 +311,7 @@ class ControlNetLoader:
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir, *extract_arg_values('--controlnet-dir')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -320,7 +329,7 @@ class DiffControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir, *extract_arg_values('--controlnet-dir')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -364,7 +373,7 @@ class T2IAdapterLoader:
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
@classmethod
def INPUT_TYPES(s):
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir, *extract_arg_values('--t2i-dir')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_t2i_adapter"
@ -381,7 +390,7 @@ class CLIPLoader:
clip_dir = os.path.join(models_dir, "clip")
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir, *extract_arg_values('--clip-dir')), supported_pt_extensions), ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -398,7 +407,7 @@ class CLIPVisionLoader:
clip_dir = os.path.join(models_dir, "clip_vision")
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir, *extract_arg_values('--clip-vision-dir')), supported_pt_extensions), ),
}}
RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip"
@ -430,7 +439,7 @@ class StyleModelLoader:
style_model_dir = os.path.join(models_dir, "style_models")
@classmethod
def INPUT_TYPES(s):
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir, *extract_arg_values('--style-model-dir')), supported_pt_extensions), )}}
RETURN_TYPES = ("STYLE_MODEL",)
FUNCTION = "load_style_model"