diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 98e9863e1..d174b32a2 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -2,7 +2,7 @@ import os from comfy_extras.chainner_models import model_loading from comfy.sd import load_torch_file import comfy.model_management -from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions +from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions, extract_arg_values import torch import comfy.utils @@ -12,7 +12,7 @@ class UpscaleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir), supported_ckpt_extensions), ), + return {"required": { "model_name": (filter_files_extensions(recursive_search(s.upscale_model_dir, *extract_arg_values('--upscaler-dir')), supported_ckpt_extensions), ), }} RETURN_TYPES = ("UPSCALE_MODEL",) FUNCTION = "load_model" diff --git a/main.py b/main.py index b2b3f1c40..1c3133d0c 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,19 @@ if __name__ == "__main__": print("Valid Command line Arguments:") print("\t--listen\t\t\tListen on 0.0.0.0 so the UI can be accessed from other computers.") print("\t--port 8188\t\t\tSet the listen port.") + + s = os.path.sep + print(f"\t--ckpt-dir path{s}to{s}dir\t\t\tAdd a path to a checkpoints dir.") + print(f"\t--clip-dir path{s}to{s}dir\t\t\tAdd a path to a clip dir.") + print(f"\t--clip-vision-dir path{s}to{s}dir\t\t\tAdd a path to a clip vision dir.") + print(f"\t--controlnet-dir path{s}to{s}dir\t\t\tAdd a path to a controlnet checkpoints dir.") + print(f"\t--embed-dir path{s}to{s}dir\t\t\tAdd a path to an embeddings dir.") + print(f"\t--lora-dir path{s}to{s}dir\t\t\tAdd a path to a lora dir.") + print(f"\t--style-model-dir path{s}to{s}dir\t\t\tAdd a path to a style models dir.") + print(f"\t--t2i-dir path{s}to{s}dir\t\t\tAdd a path to a T2I style adapter dir.") + print(f"\t--upscaler-dir path{s}to{s}dir\t\t\tAdd a path to an upscale models dir.") + print(f"\t--vae-dir path{s}to{s}dir\t\t\tAdd a path to a vae dir.") + print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.") diff --git a/nodes.py b/nodes.py index 0a0a0a9cd..b33e7fd0b 100644 --- a/nodes.py +++ b/nodes.py @@ -32,12 +32,21 @@ try: except: print("Could not import safetensors, safetensors support disabled.") -def recursive_search(directory): +def extract_arg_values(option): result = [] - for root, subdir, file in os.walk(directory, followlinks=True): - for filepath in file: - #we os.path,join directory with a blank string to generate a path separator at the end. - result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) + for i in range(len(sys.argv) - 1): + if sys.argv[i] == option: + result.append(sys.argv[i + 1]) + i += 1 + return result + +def recursive_search(*directories): + result = [] + for directory in directories: + for root, subdir, file in os.walk(directory, followlinks=True): + for filepath in file: + #we os.path,join directory with a blank string to generate a path separator at the end. + result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) return result def filter_files_extensions(files, extensions): @@ -214,7 +223,7 @@ class CheckpointLoader: @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ), - "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}} + "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir, *extract_arg_values('--ckpt-dir')), supported_ckpt_extensions), )}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -231,7 +240,7 @@ class CheckpointLoaderSimple: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ), + return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir, *extract_arg_values('--ckpt-dir')), supported_ckpt_extensions), ), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -266,7 +275,7 @@ class LoraLoader: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), - "lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ), + "lora_name": (filter_files_extensions(recursive_search(s.lora_dir, *extract_arg_values('--lora-dir')), supported_pt_extensions), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} @@ -285,7 +294,7 @@ class VAELoader: vae_dir = os.path.join(models_dir, "vae") @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} + return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir, *extract_arg_values('--vae-dir')), supported_pt_extensions), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -302,7 +311,7 @@ class ControlNetLoader: controlnet_dir = os.path.join(models_dir, "controlnet") @classmethod def INPUT_TYPES(s): - return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir, *extract_arg_values('--controlnet-dir')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -320,7 +329,7 @@ class DiffControlNetLoader: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir, *extract_arg_values('--controlnet-dir')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -364,7 +373,7 @@ class T2IAdapterLoader: t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter") @classmethod def INPUT_TYPES(s): - return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}} + return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir, *extract_arg_values('--t2i-dir')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_t2i_adapter" @@ -381,7 +390,7 @@ class CLIPLoader: clip_dir = os.path.join(models_dir, "clip") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir, *extract_arg_values('--clip-dir')), supported_pt_extensions), ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -398,7 +407,7 @@ class CLIPVisionLoader: clip_dir = os.path.join(models_dir, "clip_vision") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir, *extract_arg_values('--clip-vision-dir')), supported_pt_extensions), ), }} RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -430,7 +439,7 @@ class StyleModelLoader: style_model_dir = os.path.join(models_dir, "style_models") @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}} + return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir, *extract_arg_values('--style-model-dir')), supported_pt_extensions), )}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model"