diff --git a/.gitignore b/.gitignore index 98d91318d..b3c9bd6e6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,10 @@ __pycache__/ /models/ /temp/ /custom_nodes/ +cache +models/ +temp/ +/custom_nodes/.pytest_cache !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs @@ -14,3 +18,5 @@ venv/ /web/extensions/* !/web/extensions/logging.js.example !/web/extensions/core/ +/workflows +**/comfyui_controlnet_aux diff --git a/custom_nodes/clip_interrogator.py b/custom_nodes/clip_interrogator.py new file mode 100644 index 000000000..baac3472e --- /dev/null +++ b/custom_nodes/clip_interrogator.py @@ -0,0 +1,58 @@ +import os +import random +import sys +import hashlib +import base64 + +from clip_interrogator import Interrogator, Config +from torch import Tensor +import torchvision.transforms as T +from PIL import Image + +class ClipInterrogator: + MODEL_NAME = ["ViT-L-14/openai"] + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "clip": ("CLIP",), + "model_name": (ClipInterrogator.MODEL_NAME,), + } + } + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "clip_interrogate" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + VALUE = "" + + @classmethod + def IS_CHANGED(s, image, clip, model_name): + # TODO: Why does this not cache immidiately + return hashlib.md5(str(bytearray(image.numpy())).encode("utf-8")).hexdigest() + + def clip_interrogate(self, image, clip, model_name): + img_tensor = image[0] + # define a transform to convert a tensor to PIL image + transform = T.ToPILImage() + h, w, c = img_tensor.size() + # print(h,w,c) + # convert the tensor to PIL image using above transform + img = transform(image[0].reshape(c, h, w)) # Reshape since Tensor is using Height, Width, Color but Image needs C, H, W + config = Config(clip_model_name=model_name) + config.apply_low_vram_defaults() + ci = Interrogator(config) + ClipInterrogator.VALUE = ci.interrogate(img) + print("Image:", ClipInterrogator.VALUE) + tokens = clip.tokenize(ClipInterrogator.VALUE) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "ClipInterrogator": ClipInterrogator +} \ No newline at end of file diff --git a/custom_nodes/debug_cond.py b/custom_nodes/debug_cond.py new file mode 100644 index 000000000..b7f1ed944 --- /dev/null +++ b/custom_nodes/debug_cond.py @@ -0,0 +1,49 @@ +import datetime +import math +import os +import random + +import PIL +import einops +import torch +from torch import Tensor +import matplotlib.pyplot as plt +import torchvision.transforms as T + +class DebugCond: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP",), + "cond_input": ("CONDITIONING",), + } + } + + RETURN_TYPES = ("CONDITIONING", "IMAGE",) + FUNCTION = "debug_node" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + @classmethod + def IS_CHANGED(s, clip, cond_input): + # TODO: Why does this not cache immidiately + return random.randint(0, 10000) + + def debug_node(self, clip, cond_input): + # print("Cond Shape:", cond_input[0][0].shape) + # signal = cond_input[0][0].reshape(-1) + # stripped_signal = signal[::2048] + plt.plot(cond_input[0][0][0]) + img = PIL.Image.frombytes('RGB', plt.gcf().canvas.get_width_height(), plt.gcf().canvas.tostring_rgb()) + img_tensor = T.PILToTensor()(img) / 255.0 + img_tensor = einops.reduce(img_tensor, "a b c -> 1 b c a", "max") + return cond_input, img_tensor + +NODE_CLASS_MAPPINGS = { + "DebugCond": DebugCond +} + +# TODO: Impl into execution.py +SCRIPT_TEMPLATE_PATH = os.path.join(os.path.join(__file__, os.pardir), "debug_cond.js") diff --git a/custom_nodes/debug_latent.py b/custom_nodes/debug_latent.py new file mode 100644 index 000000000..30915dc73 --- /dev/null +++ b/custom_nodes/debug_latent.py @@ -0,0 +1,34 @@ +import math + +import torch +import torchvision.transforms as T +from PIL.Image import Image + + +class DebugLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"latent": ("LATENT",), } + } + + RETURN_TYPES = ("LATENT", "LATENT",) + FUNCTION = "latent_space" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + def latent_space(self, latent): + x = latent["samples"] + transformer = T.ToPILImage() + img: Image = transformer(x[0]) + # img.show() + # y = x * 0.75 - x * 0.25 + torch.rand(x.shape) * 0.1 + y = x * 0.5 + torch.rand(x.shape) * 0.5 + modified_latent = {"samples": y} + return (latent, modified_latent) + + +NODE_CLASS_MAPPINGS = { + "DebugLatent": DebugLatent +} diff --git a/custom_nodes/debug_model.py b/custom_nodes/debug_model.py new file mode 100644 index 000000000..a5635c083 --- /dev/null +++ b/custom_nodes/debug_model.py @@ -0,0 +1,28 @@ +import os + +from torch import Tensor + + +class DebugModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model_input": ("MODEL",), + } + } + + RETURN_TYPES = () + FUNCTION = "debug_node" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + def debug_node(self, model_input): + print("Model:", model_input) + return {} + + +NODE_CLASS_MAPPINGS = { + "DebugModel": DebugModel +} \ No newline at end of file diff --git a/custom_nodes/debug_node.py b/custom_nodes/debug_node.py new file mode 100644 index 000000000..a92706e9f --- /dev/null +++ b/custom_nodes/debug_node.py @@ -0,0 +1,23 @@ +class DebugNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond_input": ("CONDITIONING",), + "text": ("STRING", { "default": "" }), + }, + } + + RETURN_TYPES = () + FUNCTION = "debug_node" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + def debug_node(self, cond_input, text): + return { "ui": { "texts": ["ABC"] } } + + +NODE_CLASS_MAPPINGS = { + "DebugNode": DebugNode +} \ No newline at end of file diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index e37808b03..fe6b1caee 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -6,6 +6,8 @@ class Example: ------------- INPUT_TYPES (dict): Tell the main program input parameters of nodes. + IS_CHANGED (dict) -> str: + Tells the prompt loop if the current node has change on new execution based on a string identifier Attributes ---------- @@ -37,7 +39,8 @@ class Example: The type can be a list for selection. Returns: `dict`: - - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required` + - Key input_fields_group (`string`): Can be either required, hidden or optional. + - A node class must have property `required` - Value input_fields (`dict`): Contains input fields config: * Key field_name (`string`): Name of a entry-point method's argument * Value field_config (`tuple`): diff --git a/custom_nodes/image_crop.py b/custom_nodes/image_crop.py new file mode 100644 index 000000000..034aae567 --- /dev/null +++ b/custom_nodes/image_crop.py @@ -0,0 +1,63 @@ +import math + +import einops +import torch +import torchvision.transforms as T +from PIL import ImageFilter +from PIL.Image import Image + +import nodes + + +class ImageCrop: + @classmethod + def INPUT_TYPES(s): + return {"required": + { + "vae": ("VAE",), + "latent": ("LATENT",), + "center_x": ("INT", { + "default": 0, + "min": 0, # Minimum value + "max": 4096, # Maximum value + "step": 16, # Slider's step + }), + "center_y": ("INT", { + "default": 0, + "min": 0, # Minimum value + "max": 4096, # Maximum value + "step": 16, # Slider's step + }), + "pixelradius": ("INT", { + "default": 0, + "min": 0, # Minimum value + "max": 4096, # Maximum value + "step": 16, # Slider's step + }) + } + } + + RETURN_TYPES = ("LATENT", "IMAGE",) + + FUNCTION = "image_crop" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + def image_crop(self, vae, latent, center_x, center_y, pixelradius): + tensor_img = vae.decode(latent["samples"]) + stripped_tensor_img = tensor_img[0] + h, w, c = stripped_tensor_img.size() + pil_img: Image = T.ToPILImage()(einops.rearrange(stripped_tensor_img, "h w c -> c h w")) + nw, nh = center_x + pixelradius / 2, center_y + pixelradius / 2 + pil_img = pil_img.crop((center_x - pixelradius / 2, center_y - pixelradius / 2, nw, nh)) + new_tensor_img = einops.reduce(T.ToTensor()(pil_img), "c h w -> 1 h w c", "max") + # new_tensor_img = new_stripped_tensor_img.permute(0, 1, 2, 3) + pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img) + new_latent = vae.encode(pixels[:, :, :, :3]) + return ({"samples": new_latent}, new_tensor_img) + + +NODE_CLASS_MAPPINGS = { + "ImageCrop": ImageCrop +} diff --git a/custom_nodes/image_fx.py b/custom_nodes/image_fx.py new file mode 100644 index 000000000..a744e1cd1 --- /dev/null +++ b/custom_nodes/image_fx.py @@ -0,0 +1,43 @@ +import math + +import torch +import torchvision.transforms as T +from PIL import ImageFilter +from PIL.Image import Image + +import nodes + + +class ImageFX: + @classmethod + def INPUT_TYPES(s): + return {"required": + { + "vae": ("VAE",), + "latent": ("LATENT",), + } + } + + RETURN_TYPES = ("LATENT", "IMAGE",) + + FUNCTION = "image_fx" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + def image_fx(self, vae, latent): + tensor_img = vae.decode(latent["samples"]) + stripped_tensor_img = tensor_img[0] + h, w, c = stripped_tensor_img.size() + pil_img: Image = T.ToPILImage()(stripped_tensor_img.reshape(c, h, w)) + pil_img = pil_img.filter(ImageFilter.ModeFilter(2)) + new_stripped_tensor_img = T.PILToTensor()(pil_img) / 255.0 + new_tensor_img = new_stripped_tensor_img.reshape(1, h, w, c) + pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img) + new_latent = vae.encode(pixels[:, :, :, :3]) + return ({"samples": new_latent}, new_tensor_img) + + +NODE_CLASS_MAPPINGS = { + "ImageFX": ImageFX +} diff --git a/custom_nodes/test_generator.py b/custom_nodes/test_generator.py new file mode 100644 index 000000000..1873c60c0 --- /dev/null +++ b/custom_nodes/test_generator.py @@ -0,0 +1,42 @@ +import random + + +class TestGenerator: + + def __init__(self): + self.testID = 0 + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP",), + }, + "hidden": { + "testId": ("STRING",), + } + } + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "test_generator" + OUTPUT_NODE = True + + CATEGORY = "inflamously" + + TESTID = 0 + @classmethod + def IS_CHANGED(s, clip, testId=None): + # intValue = random.randint(0, 100) + # value = str(intValue) + if TestGenerator.TESTID < 2: + TestGenerator.TESTID += 1 + return str(TestGenerator.TESTID) + + def test_generator(self, clip, testId=None): + tokens = clip.tokenize("test") + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled}]], ) + +NODE_CLASS_MAPPINGS = { + "TestGenerator": TestGenerator +} \ No newline at end of file diff --git a/custom_nodes/test_node.py b/custom_nodes/test_node.py new file mode 100644 index 000000000..f2897e06c --- /dev/null +++ b/custom_nodes/test_node.py @@ -0,0 +1,48 @@ +from transformers.models import clip + + +class TestNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP", ), + "image": ("IMAGE",), + "int_field": ("INT", { + "default": 0, + "min": 0, #Minimum value + "max": 4096, #Maximum value + "step": 64, #Slider's step + "display": "number" # Cosmetic only: display as "number" or "slider" + }), + "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}), + "print_to_screen": (["enable", "disable"],), + "string_field": ("STRING", { + "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node + "default": "dong!" + }), + }, + } + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "test" + + CATEGORY = "inflamously" + + def test(self, clip, image, string_field, int_field, float_field, print_to_screen): + if print_to_screen == "enable": + print(f"""Your input contains: + string_field aka input text: {string_field} + int_field: {int_field} + float_field: {float_field} + """) + #do some processing on the image, in this example I just invert it + image = 0.5 - image + tokens = clip.tokenize("test") + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled}]], ) + +NODE_CLASS_MAPPINGS = { + "TestNode": TestNode, + "TestNode2": TestNode, +} \ No newline at end of file diff --git a/custom_nodes/test_sampler.py b/custom_nodes/test_sampler.py new file mode 100644 index 000000000..a146bc9ba --- /dev/null +++ b/custom_nodes/test_sampler.py @@ -0,0 +1,50 @@ +import latent_preview +from custom_nodes.debug_model import DebugModel +from nodes import common_ksampler + + +class TestSampler: + SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] + SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", + "uni_pc_bh2"] + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (TestSampler.SAMPLERS,), + "scheduler": (TestSampler.SCHEDULERS,), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "latent_image": ("LATENT",), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "mixture": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("LATENT", "LATENT", "LATENT") + FUNCTION = "sample" + + CATEGORY = "inflamously" + + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, + mixture=1.0): + a_val = common_ksampler(model, seed, round(steps / 2), cfg, sampler_name, scheduler, positive, negative, + latent_image, denoise=.8) + b_val = common_ksampler(model, seed + 1, round(steps / 2), cfg, sampler_name, scheduler, positive, negative, + a_val[0], denoise=.9) + x_val = common_ksampler(model, seed + 2, round(steps), cfg, sampler_name, scheduler, positive, negative, b_val[0], denoise=denoise) + return (x_val[0], a_val[0], b_val[0]) + +# c_val = [{"samples": None}] +# c_val[0]["samples"] = (a_val[0]["samples"] * 0.5 * (1.0 - mixture)) + (b_val[0]["samples"] * 0.5 * (0.0 + mixture)) +# c_val[0]["samples"] = (a_val[0]["samples"] * (1.0 - mixture)) - (b_val[0]["samples"] * (0.0 + mixture)) + +NODE_CLASS_MAPPINGS = { + "TestSampler": TestSampler +} diff --git a/execution.py b/execution.py index 5f5d6c738..f41d28b07 100644 --- a/execution.py +++ b/execution.py @@ -1,4 +1,5 @@ import os +import queue import sys import copy import json @@ -11,6 +12,8 @@ import torch import nodes import comfy.model_management +from message_queue import PromptExecutorMessageQueue + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -41,6 +44,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [unique_id] return input_data_all +# TODO: Called to execute Node's function def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists input_is_list = False @@ -72,6 +76,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() + # TODO: Executes impl node or custom_nodes function results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results @@ -117,6 +122,7 @@ def format_value(x): else: return str(x) +# TODO: Retrieves Node Input Data to be passed onto def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -152,6 +158,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: + success, error = validate_output_ui_data(server, unique_id, prompt_id, class_type, executed, output_ui) + if not success: + raise Exception("Output UI Error: {}".format(error)) + if "UI_TEMPLATE" in output_ui: + template_file = os.path.join(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)))), "custom_nodes", output_ui["UI_TEMPLATE"][0]) + with open(template_file, "r") as f: + output_ui["UI_TEMPLATE"] = f.read() + if len(output_ui["UI_TEMPLATE"]) <= 0: raise Exception("UI_TEMPLATE cannot be empty!") outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) @@ -194,6 +208,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute return (True, None, None) + +def validate_output_ui_data(server, node_id, prompt_id, class_type, executed, output_ui): + try: + json.dumps(output_ui) + return True, None + except Exception as error: + return False, error + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -230,7 +252,9 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item #is_changed = class_def.IS_CHANGED(**input_data_all) is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed - except: + except Exception as e: + # TODO: IMPL Frontend UI + print("Exception occured on IS_CHANGED: {}".format(e)) to_delete = True else: is_changed = prompt[unique_id]['is_changed'] @@ -267,6 +291,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + # TODO: Caches node instances self.object_storage = {} self.outputs_ui = {} self.old_prompt = {} @@ -382,6 +407,23 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None + def prompt_message_loop(self): + # TODO: Better refactor, is it good here? + try: + while PromptExecutorMessageQueue.get_prompt_queue().not_empty: + msg, data = PromptExecutorMessageQueue.get_prompt_queue().get(False) + if msg: + if msg == "NODE_REFRESH": + for refreshed_node_list in data: + for refreshed_node in refreshed_node_list: + keys = self.object_storage.keys() + for nodeKey in keys: + if nodeKey[1] == refreshed_node["name"]: + self.object_storage.pop(nodeKey) + break + print("PROMPT_EXECUTOR_MESSAGE_EVENT: {}".format(msg)) + except queue.Empty: + pass # Just ignore def validate_inputs(prompt, item, validated): diff --git a/main.bat b/main.bat new file mode 100644 index 000000000..ef53af9d4 --- /dev/null +++ b/main.bat @@ -0,0 +1,20 @@ +@echo off + +:: Deactivate the virtual environment +call .\venv\Scripts\deactivate.bat + +:: Activate the virtual environment +call .\venv\Scripts\activate.bat +set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib + +:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments +if %errorlevel% equ 0 ( + REM Check if the batch was started via double-click + IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" ( + REM echo This script was started by double clicking. + cmd /k python.exe main.py --auto-launch %* + ) ELSE ( + REM echo This script was started from a command prompt. + python.exe main.py --auto-launch %* + ) +) \ No newline at end of file diff --git a/main.py b/main.py index 7c5eaee0a..3eef7d2f0 100644 --- a/main.py +++ b/main.py @@ -86,11 +86,13 @@ def cuda_malloc_warning(): if cuda_malloc_warning: print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") +# TODO: Prompt handler of each node recursively def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() execution_start_time = time.perf_counter() + e.prompt_message_loop() prompt_id = item[1] e.execute(item[2], prompt_id, item[3], item[4]) q.task_done(item_id, e.outputs_ui) diff --git a/message_queue.py b/message_queue.py new file mode 100644 index 000000000..471b03adf --- /dev/null +++ b/message_queue.py @@ -0,0 +1,11 @@ +import queue + + +# This queue is loop-driven by second created thread that processes additional prompt messages + +class PromptExecutorMessageQueue: + __PROMPT_QUEUE = queue.LifoQueue() + + @staticmethod + def get_prompt_queue(): + return PromptExecutorMessageQueue.__PROMPT_QUEUE diff --git a/nodes.py b/nodes.py index 77d180526..45d2814f8 100644 --- a/nodes.py +++ b/nodes.py @@ -14,6 +14,8 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch +from message_queue import PromptExecutorMessageQueue + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -933,7 +935,7 @@ class LatentFromBatch: else: s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] return (s,) - + class RepeatLatentBatch: @classmethod def INPUT_TYPES(s): @@ -948,7 +950,7 @@ class RepeatLatentBatch: def repeat(self, samples, amount): s = samples.copy() s_in = samples["samples"] - + s["samples"] = s_in.repeat((amount, 1,1,1)) if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: masks = samples["noise_mask"] @@ -1277,7 +1279,7 @@ class SaveImage: @classmethod def INPUT_TYPES(s): - return {"required": + return {"required": {"images": ("IMAGE", ), "filename_prefix": ("STRING", {"default": "ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, @@ -1707,8 +1709,22 @@ NODE_DISPLAY_NAME_MAPPINGS = { EXTENSION_WEB_DIRS = {} + +class CustomNodeData: + def __init__(self, name="", reloaded=False): + self.name = name + self.reloaded = reloaded + + def dict(self): + return self.__dict__ + + +# TODO: Validate custom node since it throws bad errors. def load_custom_node(module_path, ignore=set()): module_name = os.path.basename(module_path) + module_reload = False + loaded_custom_node_data = [] + if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] @@ -1720,6 +1736,10 @@ def load_custom_node(module_path, ignore=set()): module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) module_dir = module_path + if module_name in sys.modules: + print("Module reload: {}".format(module_name)) + module_reload = True + module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module module_spec.loader.exec_module(module) @@ -1731,14 +1751,18 @@ def load_custom_node(module_path, ignore=set()): if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: for name in module.NODE_CLASS_MAPPINGS: - if name not in ignore: + if module_reload or name not in ignore: NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name] + # TODO: Allow multiple params for node without overwriting + loaded_custom_node_data.append( + CustomNodeData(name, module_reload).dict() + ) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) - return True + return True, loaded_custom_node_data else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") - return False + return False, loaded_custom_node_data except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) @@ -1748,29 +1772,39 @@ def load_custom_nodes(): base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] + node_data = {} for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) - if "__pycache__" in possible_modules: - possible_modules.remove("__pycache__") - for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if os.path.basename(module_path).startswith("__") or os.path.splitext(module_path)[1] != ".py" or not os.path.isfile(module_path): + print("Invalid module found: {}".format(possible_module)) + continue if module_path.endswith(".disabled"): continue time_before = time.perf_counter() - success = load_custom_node(module_path, base_node_names) + success, custom_node_data = load_custom_node(module_path, base_node_names) + if success: + node_data[module_path] = custom_node_data node_import_times.append((time.perf_counter() - time_before, module_path, success)) + print("Custom Loaded Nodes Data: {}".format(node_data)) + if len(node_import_times) > 0: print("\nImport times for custom nodes:") for n in sorted(node_import_times): if n[2]: - import_message = "" + import_message = " (SUCCESS)" else: import_message = " (IMPORT FAILED)" print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) print() + # Notify other prompt loop thread of refresh + refreshed_nodes_list = [custom_node for custom_node in [custom_node_list for _, custom_node_list in node_data.items()]] + PromptExecutorMessageQueue.get_prompt_queue().put(["NODE_REFRESH", refreshed_nodes_list]) + + return node_data + def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) @@ -1781,4 +1815,6 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py")) + # TODO: How to load without pushing this complete addon + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes/comfyui_controlnet_aux"), "__init__.py")) load_custom_nodes() diff --git a/requirements.txt b/requirements.txt index 14524485a..80f8c2e89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch torchsde einops -transformers>=4.25.1 +transformers==4.26.1 safetensors>=0.3.0 aiohttp accelerate @@ -10,3 +10,4 @@ Pillow scipy tqdm psutil +clip-interrogator==0.5.4 \ No newline at end of file diff --git a/server.py b/server.py index d04060499..14ada7c9f 100644 --- a/server.py +++ b/server.py @@ -30,7 +30,6 @@ from comfy.cli_args import args import comfy.utils import comfy.model_management - class BinaryEventTypes: PREVIEW_IMAGE = 1 UNENCODED_PREVIEW_IMAGE = 2 @@ -424,6 +423,10 @@ class PromptServer(): out[node_class] = node_info(node_class) return web.json_response(out) + @routes.get("/custom_nodes") + async def get_load_nodes(request): + return web.json_response(nodes.load_custom_nodes()) + @routes.get("/history") async def get_history(request): return web.json_response(self.prompt_queue.get_history()) @@ -446,7 +449,7 @@ class PromptServer(): print("got prompt") resp_code = 200 out_string = "" - json_data = await request.json() + json_data = await request.json() json_data = self.trigger_on_prompt(json_data) if "number" in json_data: diff --git a/web/scripts/api.js b/web/scripts/api.js index b1d245d73..f6ad2fd12 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -181,6 +181,11 @@ class ComfyApi extends EventTarget { return await resp.json(); } + async getCustomNodes() { + const resp = await this.fetchApi("/custom_nodes", { cache: "no-store" }); + return await resp.json() + } + /** * * @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue diff --git a/web/scripts/app.js b/web/scripts/app.js index 6dd1f3edd..2cb17d8c8 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1055,13 +1055,15 @@ export class ComfyApp { this.graph.setDirtyCanvas(true, false); delete this.nodePreviewImages[this.runningNodeId] }); - + // TODO: UI Update api.addEventListener("executed", ({ detail }) => { this.nodeOutputs[detail.node] = detail.output; const node = this.graph.getNodeById(detail.node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); + + this.updateNode(node, detail); } }); @@ -1181,8 +1183,8 @@ export class ComfyApp { this.loadGraphData(); } - // Save current workflow automatically - setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000); + // Save current workflow automatically + setInterval(async () => await this.saveWorkflow(), 1000); this.#addDrawNodeHandler(); this.#addDrawGroupsHandler(); @@ -1195,6 +1197,9 @@ export class ComfyApp { await this.#invokeExtensionsAsync("setup"); } + async saveWorkflow() { + localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())); + } /** * Registers nodes with the graph */ @@ -1646,11 +1651,23 @@ export class ComfyApp { this.extensions.push(extension); } - /** - * Refresh combo list on whole nodes - */ - async refreshComboInNodes() { + /** + * Refresh combo list on whole nodes + * @param {HTMLElement} button + */ + async refreshComboInNodes(button) { + if (button.getAttribute("disabled")) { + // Do not allow multiple refreshes + return; + } + button.setAttribute("disabled", true); + // Reload custom node modules under custom_nodes + const customNodeData = await api.getCustomNodes(); + + // Reload combobox const defs = await api.getNodeDefs(); + LiteGraph.clearRegisteredTypes(); + await this.registerNodesFromDefs(defs); for(let nodeNum in this.graph._nodes) { const node = this.graph._nodes[nodeNum]; @@ -1674,6 +1691,8 @@ export class ComfyApp { } } } + + button.removeAttribute("disabled"); } /** @@ -1686,6 +1705,25 @@ export class ComfyApp { this.lastExecutionError = null; this.runningNodeId = null; } + + /** + * Update Node UI Based on node state data + * TODO: Better Idea than just plain impl into App? + */ + updateNode(node, detail) { + switch (node.type) { + case "DebugNode": + const {texts} = detail.output + if (texts !== undefined && texts.length > 0) { + node.title = texts[0].substring(0, 16); + node.widgets[0].value = texts[0] + } + break; + case "DebugCond": + console.log(detail) + break; + } + } } export const app = new ComfyApp(); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index f39939bf3..aa6d5b499 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -751,7 +751,7 @@ export class ComfyUI { $el("button", { id: "comfy-refresh-button", textContent: "Refresh", - onclick: () => app.refreshComboInNodes() + onclick: () => app.refreshComboInNodes(document.getElementById("comfy-refresh-button")) }), $el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}), $el("button", {