diff --git a/README.md b/README.md index 8234af021..93c7b3ecf 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. - [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/) +- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models. +- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - Starts up very fast. - Works fully offline: will never download anything. diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 1769cc00d..9054a1c2e 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -786,6 +786,7 @@ class UNetModel(nn.Module): if control is not None: hsp += control.pop() h = th.cat([h, hsp], dim=1) + del hsp h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: diff --git a/comfy/model_management.py b/comfy/model_management.py index b8fd87966..8c859d3fa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -3,6 +3,7 @@ CPU = 0 NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 +HIGH_VRAM = 4 accelerate_enabled = False vram_state = NORMAL_VRAM @@ -27,10 +28,11 @@ if "--lowvram" in sys.argv: set_vram_to = LOW_VRAM if "--novram" in sys.argv: set_vram_to = NO_VRAM +if "--highvram" in sys.argv: + vram_state = HIGH_VRAM - -if set_vram_to != NORMAL_VRAM: +if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: import accelerate accelerate_enabled = True @@ -44,7 +46,7 @@ if set_vram_to != NORMAL_VRAM: total_vram_available_mb = int(max(256, total_vram_available_mb)) -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state]) +print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state]) current_loaded_model = None @@ -57,18 +59,24 @@ def unload_model(): global current_loaded_model global model_accelerated global current_gpu_controlnets + global vram_state + if current_loaded_model is not None: if model_accelerated: accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) model_accelerated = False - current_loaded_model.model.cpu() + #never unload models from GPU on high vram + if vram_state != HIGH_VRAM: + current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None - if len(current_gpu_controlnets) > 0: - for n in current_gpu_controlnets: - n.cpu() - current_gpu_controlnets = [] + + if vram_state != HIGH_VRAM: + if len(current_gpu_controlnets) > 0: + for n in current_gpu_controlnets: + n.cpu() + current_gpu_controlnets = [] def load_model_gpu(model): @@ -87,7 +95,7 @@ def load_model_gpu(model): current_loaded_model = model if vram_state == CPU: pass - elif vram_state == NORMAL_VRAM: + elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: model_accelerated = False real_model.cuda() else: @@ -102,6 +110,12 @@ def load_model_gpu(model): def load_controlnet_gpu(models): global current_gpu_controlnets + global vram_state + + if vram_state == LOW_VRAM or vram_state == NO_VRAM: + #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after + return + for m in current_gpu_controlnets: if m not in models: m.cpu() @@ -111,6 +125,19 @@ def load_controlnet_gpu(models): current_gpu_controlnets.append(m.cuda()) +def load_if_low_vram(model): + global vram_state + if vram_state == LOW_VRAM or vram_state == NO_VRAM: + return model.cuda() + return model + +def unload_if_low_vram(model): + global vram_state + if vram_state == LOW_VRAM or vram_state == NO_VRAM: + return model.cpu() + return model + + def get_free_memory(): dev = torch.cuda.current_device() stats = torch.cuda.memory_stats(dev) diff --git a/comfy/sd.py b/comfy/sd.py index 61a01dea6..bf67f1286 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,5 @@ import torch +import contextlib import sd1_clip import sd2_clip @@ -327,23 +328,38 @@ class VAE: return samples class ControlNet: - def __init__(self, control_model): + def __init__(self, control_model, device="cuda"): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 + self.device = device def get_control(self, x_noisy, t, cond_txt): + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) - print("set cond_hint", self.cond_hint.shape) - control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + + if self.control_model.dtype == torch.float16: + precision_scope = torch.autocast + else: + precision_scope = contextlib.nullcontext + + with precision_scope(self.device): + self.control_model = model_management.load_if_low_vram(self.control_model) + control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + self.control_model = model_management.unload_if_low_vram(self.control_model) + out = [] + autocast_enabled = torch.is_autocast_enabled() for x in control: x *= self.strength - return control + if x.dtype != output_dtype and not autocast_enabled: + x = x.to(output_dtype) + out.append(x) + return out def set_cond_hint(self, cond_hint, strength=1.0): self.cond_hint_original = cond_hint @@ -377,6 +393,11 @@ def load_controlnet(ckpt_path): return None context_dim = controlnet_data[key].shape[1] + + use_fp16 = False + if controlnet_data[key].dtype == torch.float16: + use_fp16 = True + control_model = cldm.ControlNet(image_size=32, in_channels=4, hint_channels=3, @@ -389,7 +410,8 @@ def load_controlnet(ckpt_path): transformer_depth=1, context_dim=context_dim, use_checkpoint=True, - legacy=False) + legacy=False, + use_fp16=use_fp16) if pth: class WeightsLoader(torch.nn.Module): diff --git a/custom_nodes/example_folder/main.py b/custom_nodes/example_folder/main.py deleted file mode 100644 index d906a5581..000000000 --- a/custom_nodes/example_folder/main.py +++ /dev/null @@ -1,87 +0,0 @@ -from utils import waste_cpu_resource -class ExampleFolder: - """ - A example node - - Class methods - ------------- - INPUT_TYPES (dict): - Tell the main program input parameters of nodes. - - Attributes - ---------- - RETURN_TYPES (`tuple`): - The type of each element in the output tulple. - FUNCTION (`str`): - The name of the entry-point method which will return a tuple. For example, if `FUNCTION = "execute"` then it will run Example().execute() - OUTPUT_NODE ([`bool`]): - WIP - CATEGORY (`str`): - WIP - execute(s) -> tuple || None: - The entry point method. The name of this method must be the same as the value of property `FUNCTION`. - For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. - """ - def __init__(self): - pass - - @classmethod - def INPUT_TYPES(s): - """ - Return a dictionary which contains config for all input fields. - The type can be a string indicate a type or a list indicate selection. - Prebuilt types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input in type "INT", "STRING" or "FLOAT" will be converted automatically from a string to the corresponse Python type before passing and have special config - Argument: s (`None`): Useless ig - Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` - - Value input_fields (`dict`): Contains input fields config: - * Key field_name (`string`): Name of a entry-point method's argument - * Value field_config (`tuple`): - + First value is a string indicate the type of field or a list for selection. - + Secound value is a config for type "INT", "STRING" or "FLOAT". - """ - return { - "required": { - "string_field": ("STRING", { - "multiline": True, #Allow the input to be multilined - "default": "Hello World!" - }), - "int_field": ("INT", { - "default": 0, - "min": 0, #Minimum value - "max": 4096, #Maximum value - "step": 64 #Slider's step - }), - #Like INT - "print_to_screen": (["Enable", "Disable"], {"default": "Enable"}) - }, - #"hidden": { - # "prompt": "PROMPT", - # "extra_pnginfo": "EXTRA_PNGINFO" - #}, - } - - RETURN_TYPES = ("STRING", "INT", "FLOAT", "STRING") - FUNCTION = "test" - - #OUTPUT_NODE = True - - CATEGORY = "Example" - - def test(self, string_field, int_field, print_to_screen): - rand_float = waste_cpu_resource() - if print_to_screen == "Enable": - print(f"""Your input contains: - string_field aka input text: {string_field} - int_field: {int_field} - A random float number: {rand_float} - """) - return (string_field, int_field, rand_float, print_to_screen) - -NODE_CLASS_MAPPINGS = { - "ExampleFolder": ExampleFolder -} -""" -NODE_CLASS_MAPPINGS (dict): A dictionary contains all nodes you want to export -""" \ No newline at end of file diff --git a/custom_nodes/example_folder/utils.py b/custom_nodes/example_folder/utils.py deleted file mode 100644 index cc59f97f4..000000000 --- a/custom_nodes/example_folder/utils.py +++ /dev/null @@ -1,4 +0,0 @@ -import torch -def waste_cpu_resource(): - x = torch.rand(1, 1e6, dtype=torch.float64).cpu() - return x.numpy()[0, 1] \ No newline at end of file diff --git a/custom_nodes/example.py b/custom_nodes/example_node.py.example similarity index 61% rename from custom_nodes/example.py rename to custom_nodes/example_node.py.example index ff3a46bcc..1bb1a5a37 100644 --- a/custom_nodes/example.py +++ b/custom_nodes/example_node.py.example @@ -12,11 +12,13 @@ class Example: RETURN_TYPES (`tuple`): The type of each element in the output tulple. FUNCTION (`str`): - The name of the entry-point method which will return a tuple. For example, if `FUNCTION = "execute"` then it will run Example().execute() + The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute() OUTPUT_NODE ([`bool`]): - WIP + If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example. + The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected. + Assumed to be False if not present. CATEGORY (`str`): - WIP + The category the node should appear in the UI. execute(s) -> tuple || None: The entry point method. The name of this method must be the same as the value of property `FUNCTION`. For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`. @@ -28,10 +30,10 @@ class Example: def INPUT_TYPES(s): """ Return a dictionary which contains config for all input fields. - The type can be a string indicate a type or a list indicate selection. - Prebuilt types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". - Input in type "INT", "STRING" or "FLOAT" will be converted automatically from a string to the corresponse Python type before passing and have special config - Argument: s (`None`): Useless ig + Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT". + Input types "INT", "STRING" or "FLOAT" are special values for fields on the node. + The type can be a list for selection. + Returns: `dict`: - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` - Value input_fields (`dict`): Contains input fields config: @@ -42,46 +44,43 @@ class Example: """ return { "required": { - "string_field": ("STRING", { - "multiline": True, #Allow the input to be multilined - "default": "Hello World!" - }), + "image": ("IMAGE",), "int_field": ("INT", { "default": 0, "min": 0, #Minimum value "max": 4096, #Maximum value "step": 64 #Slider's step }), - #Like INT "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "print_to_screen": (["Enable", "Disable"], {"default": "Enable"}) + "print_to_screen": (["enable", "disable"],), + "string_field": ("STRING", { + "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node + "default": "Hello World!" + }), }, - #"hidden": { - # "prompt": "PROMPT", - # "extra_pnginfo": "EXTRA_PNGINFO" - #}, } - RETURN_TYPES = ("STRING", "INT", "FLOAT", "STRING") + RETURN_TYPES = ("IMAGE",) FUNCTION = "test" - #OUTPUT_NODE = True + #OUTPUT_NODE = False CATEGORY = "Example" - def test(self, string_field, int_field, float_field, print_to_screen): - if print_to_screen == "Enable": + def test(self, image, string_field, int_field, float_field, print_to_screen): + if print_to_screen == "enable": print(f"""Your input contains: string_field aka input text: {string_field} int_field: {int_field} float_field: {float_field} """) - return (string_field, int_field, float_field, print_to_screen) + #do some processing on the image, in this example I just invert it + image = 1.0 - image + return (image,) +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique NODE_CLASS_MAPPINGS = { "Example": Example } -""" -NODE_CLASS_MAPPINGS (dict): A dictionary contains all nodes you want to export -""" \ No newline at end of file diff --git a/main.py b/main.py index f5aec4424..a162e1ed3 100644 --- a/main.py +++ b/main.py @@ -29,6 +29,7 @@ if __name__ == "__main__": print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n") print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.") print() + print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n") print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.") print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.") print("\t--novram\t\t\tWhen lowvram isn't enough.") diff --git a/nodes.py b/nodes.py index aafbd31a3..ed30c0ca5 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import sys import json import hashlib import copy +import traceback from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -774,7 +775,7 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): return {"required": - {"image": (os.listdir(s.input_dir), ), + {"image": (sorted(os.listdir(s.input_dir)), ), "channel": (["alpha", "red", "green", "blue"], ),} } @@ -861,29 +862,28 @@ NODE_CLASS_MAPPINGS = { CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes") def load_custom_nodes(): possible_modules = os.listdir(CUSTOM_NODE_PATH) - try: - #Comment out these two lines if you want to test - possible_modules.remove("example.py") - possible_modules.remove("example_folder") + if "__pycache__" in possible_modules: possible_modules.remove("__pycache__") - except ValueError: pass + for possible_module in possible_modules: module_path = os.path.join(CUSTOM_NODE_PATH, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - + + module_name = "custom_node_module.{}".format(possible_module) try: if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location(os.path.basename(module_path), module_path) + module_spec = importlib.util.spec_from_file_location(module_name, module_path) else: - module_spec = importlib.util.spec_from_file_location(module_path, "main.py") + module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) module = importlib.util.module_from_spec(module_spec) + sys.modules[module_name] = module module_spec.loader.exec_module(module) - if getattr(module, "NODE_CLASS_MAPPINGS") is not None: + if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) else: print(f"Skip {possible_module} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") - except ImportError as e: - print(f"Cannot import {possible_module} module for custom nodes.") - print(e) + except Exception as e: + print(traceback.format_exc()) + print(f"Cannot import {possible_module} module for custom nodes:", e) load_custom_nodes() \ No newline at end of file diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index cdf182b8d..2e364f165 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -85,7 +85,7 @@ { "cell_type": "markdown", "source": [ - "Run ComfyUI:" + "Run ComfyUI (use the fp16 model configs for more speed):" ], "metadata": { "id": "gggggggggg" @@ -112,7 +112,7 @@ "\n", "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", "\n", - "!python main.py" + "!python main.py --highvram" ], "metadata": { "id": "hhhhhhhhhh"