diff --git a/README.md b/README.md index 925caa732..6bef25cee 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae +Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier. + ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index ffae81c49..35d44164f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") +parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 3d42d7806..1e3fc9359 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -1,6 +1,7 @@ import comfy.sd import comfy.utils import comfy.model_base +import comfy.model_management import folder_paths import json @@ -178,6 +179,39 @@ class CheckpointSave: comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) return {} +class VAESave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) + return {} NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, @@ -186,4 +220,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, + "VAESave": VAESave, } diff --git a/folder_paths.py b/folder_paths.py index 4a10c68e7..898513b0e 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -46,6 +46,10 @@ def set_temp_directory(temp_dir): global temp_directory temp_directory = temp_dir +def set_input_directory(input_dir): + global input_directory + input_directory = input_dir + def get_output_directory(): global output_directory return output_directory diff --git a/main.py b/main.py index 7c5eaee0a..875ea1aa9 100644 --- a/main.py +++ b/main.py @@ -175,6 +175,11 @@ if __name__ == "__main__": print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) + if args.input_directory: + input_dir = os.path.abspath(args.input_directory) + print(f"Setting input directory to: {input_dir}") + folder_paths.set_input_directory(input_dir) + if args.quick_test_for_ci: exit(0)