Merge branch 'comfyanonymous:master' into controlnet-annotator

Fannovel16 2023-02-18 18:43:12 +07:00 committed by GitHub
commit c4048cc39d
10 changed files with 107 additions and 146 deletions

View File

@@ -20,6 +20,8 @@ This ui will let you design and execute advanced stable diffusion pipelines using
- Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- Starts up very fast.
- Works fully offline: will never download anything.

View File

@@ -786,6 +786,7 @@ class UNetModel(nn.Module):
if control is not None:
hsp += control.pop()
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
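In isolation, the decoder pattern this hunk touches looks like the following sketch (hypothetical shapes; `block` stands in for `module(h, emb, context)`): each ControlNet residual is popped and added to the matching UNet skip tensor before the channel concat, and the new `del hsp` releases the temporary sooner.

```python
import torch

def block(h):  # stand-in for module(h, emb, context)
    return h

# hypothetical shapes: two decoder levels with matching skip/control pairs
hs = [torch.randn(1, 320, 32, 32), torch.randn(1, 320, 32, 32)]       # UNet skip tensors
control = [torch.randn(1, 320, 32, 32), torch.randn(1, 320, 32, 32)]  # ControlNet residuals

h = torch.randn(1, 320, 32, 32)
while hs:
    hsp = hs.pop()
    if control is not None:
        hsp += control.pop()        # inject this level's control residual
    h = torch.cat([h, hsp], dim=1)  # channel concat into the decoder input
    del hsp                         # dropping the temporary frees VRAM earlier
    h = block(h)
```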

View File

@@ -3,6 +3,7 @@ CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
accelerate_enabled = False
vram_state = NORMAL_VRAM
@@ -27,10 +28,11 @@ if "--lowvram" in sys.argv:
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
vram_state = HIGH_VRAM
if set_vram_to != NORMAL_VRAM:
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try:
import accelerate
accelerate_enabled = True
@@ -44,7 +46,7 @@ if set_vram_to != NORMAL_VRAM:
total_vram_available_mb = int(max(256, total_vram_available_mb))
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
current_loaded_model = None
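Condensed, the flag handling these hunks add behaves like this standalone sketch (the accelerate import and VRAM auto-detection of the real module are simplified away):

```python
import sys

# the five states, mirroring model_management.py
CPU, NO_VRAM, LOW_VRAM, NORMAL_VRAM, HIGH_VRAM = 0, 1, 2, 3, 4

vram_state = NORMAL_VRAM
set_vram_to = NORMAL_VRAM
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
    vram_state = HIGH_VRAM  # set directly; the accelerate path below is skipped

if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
    vram_state = set_vram_to  # the real module first tries to import accelerate

print("Set vram state to:",
      ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
```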
@@ -57,18 +59,24 @@ def unload_model():
global current_loaded_model
global model_accelerated
global current_gpu_controlnets
global vram_state
if current_loaded_model is not None:
if model_accelerated:
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
model_accelerated = False
current_loaded_model.model.cpu()
#never unload models from GPU on high vram
if vram_state != HIGH_VRAM:
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
current_gpu_controlnets = []
if vram_state != HIGH_VRAM:
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
current_gpu_controlnets = []
def load_model_gpu(model):
@@ -87,7 +95,7 @@ def load_model_gpu(model):
current_loaded_model = model
if vram_state == CPU:
pass
elif vram_state == NORMAL_VRAM:
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
model_accelerated = False
real_model.cuda()
else:
@@ -102,6 +110,12 @@ def load_model_gpu(model):
def load_controlnet_gpu(models):
global current_gpu_controlnets
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
@@ -111,6 +125,19 @@ def load_controlnet_gpu(models):
current_gpu_controlnets.append(m.cuda())
def load_if_low_vram(model):
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
return model.cuda()
return model
def unload_if_low_vram(model):
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
return model.cpu()
return model
def get_free_memory():
dev = torch.cuda.current_device()
stats = torch.cuda.memory_stats(dev)
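A minimal sketch of how the two new helpers are meant to bracket a forward pass (hypothetical `torch.nn.Linear` standing in for a controlnet): on LOW/NO vram the module lives on the GPU only for the duration of the call; on every other state both helpers return the model untouched.

```python
import torch

LOW_VRAM, NO_VRAM = 2, 1
vram_state = 3  # NORMAL_VRAM; module-level state, as in model_management

def load_if_low_vram(model):
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        return model.cuda()
    return model

def unload_if_low_vram(model):
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        return model.cpu()
    return model

net = torch.nn.Linear(4, 4)    # stand-in for a controlnet
x = torch.randn(1, 4)
net = load_if_low_vram(net)    # moved to the GPU only when vram is scarce
y = net(x)
net = unload_if_low_vram(net)  # and moved back right after the call
```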

View File

@@ -1,4 +1,5 @@
import torch
import contextlib
import sd1_clip
import sd2_clip
@@ -327,23 +328,38 @@ class VAE:
return samples
class ControlNet:
def __init__(self, control_model):
def __init__(self, control_model, device="cuda"):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.device = device
def get_control(self, x_noisy, t, cond_txt):
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
print("set cond_hint", self.cond_hint.shape)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if self.control_model.dtype == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
with precision_scope(self.device):
self.control_model = model_management.load_if_low_vram(self.control_model)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = []
autocast_enabled = torch.is_autocast_enabled()
for x in control:
x *= self.strength
return control
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)
out.append(x)
return out
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
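The precision-scope trick in `get_control`, isolated into a sketch (one assumption: the weight dtype is read from the first parameter rather than the module's `dtype` attribute): `contextlib.nullcontext` accepts and ignores an argument, so a single `with` statement covers both the fp16 autocast path and the fp32 no-op path.

```python
import contextlib
import torch

def run_control_model(control_model, device, *args, **kwargs):
    # autocast only when the weights are fp16; nullcontext makes fp32 a no-op
    if next(control_model.parameters()).dtype == torch.float16:
        precision_scope = torch.autocast
    else:
        precision_scope = contextlib.nullcontext
    with precision_scope(device):
        return control_model(*args, **kwargs)

# e.g. run_control_model(model, "cuda", x=x_noisy, hint=hint, timesteps=t, context=cond)
```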
@@ -377,6 +393,11 @@ def load_controlnet(ckpt_path):
return None
context_dim = controlnet_data[key].shape[1]
use_fp16 = False
if controlnet_data[key].dtype == torch.float16:
use_fp16 = True
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
@@ -389,7 +410,8 @@ def load_controlnet(ckpt_path):
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
legacy=False)
legacy=False,
use_fp16=use_fp16)
if pth:
class WeightsLoader(torch.nn.Module):
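The fp16 detection reduces to checking one representative tensor's dtype in the loaded state dict; a sketch with a hypothetical key:

```python
import torch

# hypothetical state dict with one representative key
controlnet_data = {"middle_block_out.0.weight": torch.zeros(8, dtype=torch.float16)}
key = "middle_block_out.0.weight"

use_fp16 = False
if controlnet_data[key].dtype == torch.float16:
    use_fp16 = True
# use_fp16 is then forwarded to the ControlNet constructor (use_fp16=use_fp16)
print(use_fp16)  # True
```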

View File

@@ -1,87 +0,0 @@
from utils import waste_cpu_resource
class ExampleFolder:
"""
An example node
Class methods
-------------
INPUT_TYPES (dict):
Tell the main program the input parameters of the node.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
FUNCTION (`str`):
The name of the entry-point method which will return a tuple. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
WIP
CATEGORY (`str`):
WIP
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
Return a dictionary which contains config for all input fields.
The type can be a string indicating a type or a list indicating a selection.
Prebuilt types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input of type "INT", "STRING" or "FLOAT" will be converted automatically from a string to the corresponding Python type before being passed, and can have special config.
Argument: s (`None`): Unused.
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
* Key field_name (`string`): Name of an entry-point method's argument
* Value field_config (`tuple`):
+ First value is a string indicating the type of the field or a list for selection.
+ Second value is a config for type "INT", "STRING" or "FLOAT".
"""
return {
"required": {
"string_field": ("STRING", {
"multiline": True, #Allow the input to be multilined
"default": "Hello World!"
}),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64 #Slider's step
}),
#Like INT
"print_to_screen": (["Enable", "Disable"], {"default": "Enable"})
},
#"hidden": {
# "prompt": "PROMPT",
# "extra_pnginfo": "EXTRA_PNGINFO"
#},
}
RETURN_TYPES = ("STRING", "INT", "FLOAT", "STRING")
FUNCTION = "test"
#OUTPUT_NODE = True
CATEGORY = "Example"
def test(self, string_field, int_field, print_to_screen):
rand_float = waste_cpu_resource()
if print_to_screen == "Enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
int_field: {int_field}
A random float number: {rand_float}
""")
return (string_field, int_field, rand_float, print_to_screen)
NODE_CLASS_MAPPINGS = {
"ExampleFolder": ExampleFolder
}
"""
NODE_CLASS_MAPPINGS (dict): A dictionary that contains all the nodes you want to export
"""

View File

@@ -1,4 +0,0 @@
import torch
def waste_cpu_resource():
x = torch.rand(1, 1000000, dtype=torch.float64).cpu() #size arguments must be ints
return x.numpy()[0, 1]

View File

@@ -12,11 +12,13 @@ class Example:
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
FUNCTION (`str`):
The name of the entry-point method which will return a tuple. For example, if `FUNCTION = "execute"` then it will run Example().execute()
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
WIP
Whether this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
The backend iterates over these output nodes and tries to execute all their parents if their parent graph is properly connected.
Assumed to be False if not present.
CATEGORY (`str`):
WIP
The category the node should appear in the UI.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
@@ -28,10 +30,10 @@ class Example:
def INPUT_TYPES(s):
"""
Return a dictionary which contains config for all input fields.
The type can be a string indicating a type or a list indicating a selection.
Prebuilt types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input of type "INT", "STRING" or "FLOAT" will be converted automatically from a string to the corresponding Python type before being passed, and can have special config.
Argument: s (`None`): Unused.
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
The type can be a list for selection.
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
@@ -42,46 +44,43 @@ class Example:
"""
return {
"required": {
"string_field": ("STRING", {
"multiline": True, #Allow the input to be multilined
"default": "Hello World!"
}),
"image": ("IMAGE",),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64 #Slider's step
}),
#Like INT
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"print_to_screen": (["Enable", "Disable"], {"default": "Enable"})
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!"
}),
},
#"hidden": {
# "prompt": "PROMPT",
# "extra_pnginfo": "EXTRA_PNGINFO"
#},
}
RETURN_TYPES = ("STRING", "INT", "FLOAT", "STRING")
RETURN_TYPES = ("IMAGE",)
FUNCTION = "test"
#OUTPUT_NODE = True
#OUTPUT_NODE = False
CATEGORY = "Example"
def test(self, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "Enable":
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
int_field: {int_field}
float_field: {float_field}
""")
return (string_field, int_field, float_field, print_to_screen)
#do some processing on the image, in this example I just invert it
image = 1.0 - image
return (image,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Example": Example
}
"""
NODE_CLASS_MAPPINGS (dict): A dictionary that contains all the nodes you want to export
"""

View File

@@ -29,6 +29,7 @@ if __name__ == "__main__":
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print()
print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
print("\t--novram\t\t\tWhen lowvram isn't enough.")

View File

@ -5,6 +5,7 @@ import sys
import json
import hashlib
import copy
import traceback
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -774,7 +775,7 @@ class LoadImageMask:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"image": (os.listdir(s.input_dir), ),
{"image": (sorted(os.listdir(s.input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),}
}
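The `sorted` matters because the list literally becomes the dropdown's choices, so an unsorted `os.listdir` could reorder the widget between platforms or runs; a sketch with a hypothetical folder:

```python
import os

input_dir = "input"  # hypothetical image folder
os.makedirs(input_dir, exist_ok=True)
choices = sorted(os.listdir(input_dir))  # stable order across platforms/runs
required = {
    "image": (choices,),
    "channel": (["alpha", "red", "green", "blue"],),
}
print(required)
```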
@@ -861,29 +862,28 @@ NODE_CLASS_MAPPINGS = {
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
def load_custom_nodes():
possible_modules = os.listdir(CUSTOM_NODE_PATH)
try:
#Comment out these two lines if you want to test
possible_modules.remove("example.py")
possible_modules.remove("example_folder")
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
except ValueError: pass
for possible_module in possible_modules:
module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
module_name = "custom_node_module.{}".format(possible_module)
try:
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(module_path), module_path)
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
else:
module_spec = importlib.util.spec_from_file_location(module_path, "main.py")
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
if getattr(module, "NODE_CLASS_MAPPINGS") is not None:
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
else:
print(f"Skip {possible_module} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
except ImportError as e:
print(f"Cannot import {possible_module} module for custom nodes.")
print(e)
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {possible_module} module for custom nodes:", e)
load_custom_nodes()
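The loader's import pattern, isolated (hypothetical path; registering the module in `sys.modules` under a unique dotted name before `exec_module` lets the module be looked up by name during its own execution):

```python
import importlib.util
import sys

def load_module_from_path(module_name, path):
    # file module, or a package directory with __init__.py, as in the loader
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register first so the module is importable by name
    spec.loader.exec_module(module)
    return module

# hypothetical usage:
# mod = load_module_from_path("custom_node_module.my_nodes", "custom_nodes/my_nodes.py")
# mappings = getattr(mod, "NODE_CLASS_MAPPINGS", None)
```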

View File

@@ -85,7 +85,7 @@
{
"cell_type": "markdown",
"source": [
"Run ComfyUI:"
"Run ComfyUI (use the fp16 model configs for more speed):"
],
"metadata": {
"id": "gggggggggg"
@@ -112,7 +112,7 @@
"\n",
"threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n",
"\n",
"!python main.py"
"!python main.py --highvram"
],
"metadata": {
"id": "hhhhhhhhhh"