mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 14:02:37 +08:00
feat: add image cropping
feat: remove custom js stuff since it just doesnt work well feat: add cond debugging feat: impls PoC refresh of custom_nodes + custom_node extensions feat: ignore workflows folder feat: add batch file to start application under windows feat: integrate reload custom node into refresh feat: update custom node ui feat: impl node change event handling !WIP! feat: add CustomNodeData class for reuse feat: remove all reloaded nodes for test purposes and save graph afterwards !WIP! feat: remove unused registeredNodes feat: comment out graph removal feat: comment on some functions for proper understanding and bookmarking (for now) feat: comment node execution location feat: add exception for IS_CHANGED issues feat: extend example_node README !WIP! feat: custom test nodes for now !WIP! feat: avoid refresh spam feat: add debug_cond custom_node with WIP ui feat: add hint for validating output_ui data feat: pass refresh button into combo function feat: impl output ui error feat: auto refresh nodes fix: various minor issues !WIP! feat: barebone JS scripting in BE for ui templating !WIP! feat: impl interrogation with clip feat: impl more debug samplers feat: change requirements.txt for transformers fix: __init__.py issues when importing custom_nodes feat: temp ignore 3rdparty code feat: add custom_nodes debug_latent and image_fx
This commit is contained in:
parent
0966d3ce82
commit
98155446bf
6
.gitignore
vendored
6
.gitignore
vendored
@ -6,6 +6,10 @@ __pycache__/
|
||||
/models/
|
||||
/temp/
|
||||
/custom_nodes/
|
||||
cache
|
||||
models/
|
||||
temp/
|
||||
/custom_nodes/.pytest_cache
|
||||
!custom_nodes/example_node.py.example
|
||||
extra_model_paths.yaml
|
||||
/.vs
|
||||
@ -14,3 +18,5 @@ venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
!/web/extensions/core/
|
||||
/workflows
|
||||
**/comfyui_controlnet_aux
|
||||
|
||||
58
custom_nodes/clip_interrogator.py
Normal file
58
custom_nodes/clip_interrogator.py
Normal file
@ -0,0 +1,58 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
from clip_interrogator import Interrogator, Config
|
||||
from torch import Tensor
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
|
||||
class ClipInterrogator:
|
||||
MODEL_NAME = ["ViT-L-14/openai"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"clip": ("CLIP",),
|
||||
"model_name": (ClipInterrogator.MODEL_NAME,),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "clip_interrogate"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
VALUE = ""
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, clip, model_name):
|
||||
# TODO: Why does this not cache immidiately
|
||||
return hashlib.md5(str(bytearray(image.numpy())).encode("utf-8")).hexdigest()
|
||||
|
||||
def clip_interrogate(self, image, clip, model_name):
|
||||
img_tensor = image[0]
|
||||
# define a transform to convert a tensor to PIL image
|
||||
transform = T.ToPILImage()
|
||||
h, w, c = img_tensor.size()
|
||||
# print(h,w,c)
|
||||
# convert the tensor to PIL image using above transform
|
||||
img = transform(image[0].reshape(c, h, w)) # Reshape since Tensor is using Height, Width, Color but Image needs C, H, W
|
||||
config = Config(clip_model_name=model_name)
|
||||
config.apply_low_vram_defaults()
|
||||
ci = Interrogator(config)
|
||||
ClipInterrogator.VALUE = ci.interrogate(img)
|
||||
print("Image:", ClipInterrogator.VALUE)
|
||||
tokens = clip.tokenize(ClipInterrogator.VALUE)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled}]], )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ClipInterrogator": ClipInterrogator
|
||||
}
|
||||
49
custom_nodes/debug_cond.py
Normal file
49
custom_nodes/debug_cond.py
Normal file
@ -0,0 +1,49 @@
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
import PIL
|
||||
import einops
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import matplotlib.pyplot as plt
|
||||
import torchvision.transforms as T
|
||||
|
||||
class DebugCond:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP",),
|
||||
"cond_input": ("CONDITIONING",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "IMAGE",)
|
||||
FUNCTION = "debug_node"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, clip, cond_input):
|
||||
# TODO: Why does this not cache immidiately
|
||||
return random.randint(0, 10000)
|
||||
|
||||
def debug_node(self, clip, cond_input):
|
||||
# print("Cond Shape:", cond_input[0][0].shape)
|
||||
# signal = cond_input[0][0].reshape(-1)
|
||||
# stripped_signal = signal[::2048]
|
||||
plt.plot(cond_input[0][0][0])
|
||||
img = PIL.Image.frombytes('RGB', plt.gcf().canvas.get_width_height(), plt.gcf().canvas.tostring_rgb())
|
||||
img_tensor = T.PILToTensor()(img) / 255.0
|
||||
img_tensor = einops.reduce(img_tensor, "a b c -> 1 b c a", "max")
|
||||
return cond_input, img_tensor
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DebugCond": DebugCond
|
||||
}
|
||||
|
||||
# TODO: Impl into execution.py
|
||||
SCRIPT_TEMPLATE_PATH = os.path.join(os.path.join(__file__, os.pardir), "debug_cond.js")
|
||||
34
custom_nodes/debug_latent.py
Normal file
34
custom_nodes/debug_latent.py
Normal file
@ -0,0 +1,34 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class DebugLatent:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"latent": ("LATENT",), }
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT", "LATENT",)
|
||||
FUNCTION = "latent_space"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def latent_space(self, latent):
|
||||
x = latent["samples"]
|
||||
transformer = T.ToPILImage()
|
||||
img: Image = transformer(x[0])
|
||||
# img.show()
|
||||
# y = x * 0.75 - x * 0.25 + torch.rand(x.shape) * 0.1
|
||||
y = x * 0.5 + torch.rand(x.shape) * 0.5
|
||||
modified_latent = {"samples": y}
|
||||
return (latent, modified_latent)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DebugLatent": DebugLatent
|
||||
}
|
||||
28
custom_nodes/debug_model.py
Normal file
28
custom_nodes/debug_model.py
Normal file
@ -0,0 +1,28 @@
|
||||
import os
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class DebugModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model_input": ("MODEL",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "debug_node"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def debug_node(self, model_input):
|
||||
print("Model:", model_input)
|
||||
return {}
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DebugModel": DebugModel
|
||||
}
|
||||
23
custom_nodes/debug_node.py
Normal file
23
custom_nodes/debug_node.py
Normal file
@ -0,0 +1,23 @@
|
||||
class DebugNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"cond_input": ("CONDITIONING",),
|
||||
"text": ("STRING", { "default": "" }),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "debug_node"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def debug_node(self, cond_input, text):
|
||||
return { "ui": { "texts": ["ABC"] } }
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DebugNode": DebugNode
|
||||
}
|
||||
@ -6,6 +6,8 @@ class Example:
|
||||
-------------
|
||||
INPUT_TYPES (dict):
|
||||
Tell the main program input parameters of nodes.
|
||||
IS_CHANGED (dict) -> str:
|
||||
Tells the prompt loop if the current node has change on new execution based on a string identifier
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@ -37,7 +39,8 @@ class Example:
|
||||
The type can be a list for selection.
|
||||
|
||||
Returns: `dict`:
|
||||
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
|
||||
- Key input_fields_group (`string`): Can be either required, hidden or optional.
|
||||
- A node class must have property `required`
|
||||
- Value input_fields (`dict`): Contains input fields config:
|
||||
* Key field_name (`string`): Name of a entry-point method's argument
|
||||
* Value field_config (`tuple`):
|
||||
|
||||
63
custom_nodes/image_crop.py
Normal file
63
custom_nodes/image_crop.py
Normal file
@ -0,0 +1,63 @@
|
||||
import math
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import ImageFilter
|
||||
from PIL.Image import Image
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
class ImageCrop:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{
|
||||
"vae": ("VAE",),
|
||||
"latent": ("LATENT",),
|
||||
"center_x": ("INT", {
|
||||
"default": 0,
|
||||
"min": 0, # Minimum value
|
||||
"max": 4096, # Maximum value
|
||||
"step": 16, # Slider's step
|
||||
}),
|
||||
"center_y": ("INT", {
|
||||
"default": 0,
|
||||
"min": 0, # Minimum value
|
||||
"max": 4096, # Maximum value
|
||||
"step": 16, # Slider's step
|
||||
}),
|
||||
"pixelradius": ("INT", {
|
||||
"default": 0,
|
||||
"min": 0, # Minimum value
|
||||
"max": 4096, # Maximum value
|
||||
"step": 16, # Slider's step
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT", "IMAGE",)
|
||||
|
||||
FUNCTION = "image_crop"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def image_crop(self, vae, latent, center_x, center_y, pixelradius):
|
||||
tensor_img = vae.decode(latent["samples"])
|
||||
stripped_tensor_img = tensor_img[0]
|
||||
h, w, c = stripped_tensor_img.size()
|
||||
pil_img: Image = T.ToPILImage()(einops.rearrange(stripped_tensor_img, "h w c -> c h w"))
|
||||
nw, nh = center_x + pixelradius / 2, center_y + pixelradius / 2
|
||||
pil_img = pil_img.crop((center_x - pixelradius / 2, center_y - pixelradius / 2, nw, nh))
|
||||
new_tensor_img = einops.reduce(T.ToTensor()(pil_img), "c h w -> 1 h w c", "max")
|
||||
# new_tensor_img = new_stripped_tensor_img.permute(0, 1, 2, 3)
|
||||
pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img)
|
||||
new_latent = vae.encode(pixels[:, :, :, :3])
|
||||
return ({"samples": new_latent}, new_tensor_img)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop
|
||||
}
|
||||
43
custom_nodes/image_fx.py
Normal file
43
custom_nodes/image_fx.py
Normal file
@ -0,0 +1,43 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import ImageFilter
|
||||
from PIL.Image import Image
|
||||
|
||||
import nodes
|
||||
|
||||
|
||||
class ImageFX:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{
|
||||
"vae": ("VAE",),
|
||||
"latent": ("LATENT",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT", "IMAGE",)
|
||||
|
||||
FUNCTION = "image_fx"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def image_fx(self, vae, latent):
|
||||
tensor_img = vae.decode(latent["samples"])
|
||||
stripped_tensor_img = tensor_img[0]
|
||||
h, w, c = stripped_tensor_img.size()
|
||||
pil_img: Image = T.ToPILImage()(stripped_tensor_img.reshape(c, h, w))
|
||||
pil_img = pil_img.filter(ImageFilter.ModeFilter(2))
|
||||
new_stripped_tensor_img = T.PILToTensor()(pil_img) / 255.0
|
||||
new_tensor_img = new_stripped_tensor_img.reshape(1, h, w, c)
|
||||
pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img)
|
||||
new_latent = vae.encode(pixels[:, :, :, :3])
|
||||
return ({"samples": new_latent}, new_tensor_img)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageFX": ImageFX
|
||||
}
|
||||
42
custom_nodes/test_generator.py
Normal file
42
custom_nodes/test_generator.py
Normal file
@ -0,0 +1,42 @@
|
||||
import random
|
||||
|
||||
|
||||
class TestGenerator:
|
||||
|
||||
def __init__(self):
|
||||
self.testID = 0
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP",),
|
||||
},
|
||||
"hidden": {
|
||||
"testId": ("STRING",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "test_generator"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
TESTID = 0
|
||||
@classmethod
|
||||
def IS_CHANGED(s, clip, testId=None):
|
||||
# intValue = random.randint(0, 100)
|
||||
# value = str(intValue)
|
||||
if TestGenerator.TESTID < 2:
|
||||
TestGenerator.TESTID += 1
|
||||
return str(TestGenerator.TESTID)
|
||||
|
||||
def test_generator(self, clip, testId=None):
|
||||
tokens = clip.tokenize("test")
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled}]], )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TestGenerator": TestGenerator
|
||||
}
|
||||
48
custom_nodes/test_node.py
Normal file
48
custom_nodes/test_node.py
Normal file
@ -0,0 +1,48 @@
|
||||
from transformers.models import clip
|
||||
|
||||
|
||||
class TestNode:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"image": ("IMAGE",),
|
||||
"int_field": ("INT", {
|
||||
"default": 0,
|
||||
"min": 0, #Minimum value
|
||||
"max": 4096, #Maximum value
|
||||
"step": 64, #Slider's step
|
||||
"display": "number" # Cosmetic only: display as "number" or "slider"
|
||||
}),
|
||||
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}),
|
||||
"print_to_screen": (["enable", "disable"],),
|
||||
"string_field": ("STRING", {
|
||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||
"default": "dong!"
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "test"
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def test(self, clip, image, string_field, int_field, float_field, print_to_screen):
|
||||
if print_to_screen == "enable":
|
||||
print(f"""Your input contains:
|
||||
string_field aka input text: {string_field}
|
||||
int_field: {int_field}
|
||||
float_field: {float_field}
|
||||
""")
|
||||
#do some processing on the image, in this example I just invert it
|
||||
image = 0.5 - image
|
||||
tokens = clip.tokenize("test")
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return ([[cond, {"pooled_output": pooled}]], )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TestNode": TestNode,
|
||||
"TestNode2": TestNode,
|
||||
}
|
||||
50
custom_nodes/test_sampler.py
Normal file
50
custom_nodes/test_sampler.py
Normal file
@ -0,0 +1,50 @@
|
||||
import latent_preview
|
||||
from custom_nodes.debug_model import DebugModel
|
||||
from nodes import common_ksampler
|
||||
|
||||
|
||||
class TestSampler:
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc",
|
||||
"uni_pc_bh2"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
||||
"sampler_name": (TestSampler.SAMPLERS,),
|
||||
"scheduler": (TestSampler.SCHEDULERS,),
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"latent_image": ("LATENT",),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"mixture": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.1}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LATENT", "LATENT", "LATENT")
|
||||
FUNCTION = "sample"
|
||||
|
||||
CATEGORY = "inflamously"
|
||||
|
||||
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
|
||||
mixture=1.0):
|
||||
a_val = common_ksampler(model, seed, round(steps / 2), cfg, sampler_name, scheduler, positive, negative,
|
||||
latent_image, denoise=.8)
|
||||
b_val = common_ksampler(model, seed + 1, round(steps / 2), cfg, sampler_name, scheduler, positive, negative,
|
||||
a_val[0], denoise=.9)
|
||||
x_val = common_ksampler(model, seed + 2, round(steps), cfg, sampler_name, scheduler, positive, negative, b_val[0], denoise=denoise)
|
||||
return (x_val[0], a_val[0], b_val[0])
|
||||
|
||||
# c_val = [{"samples": None}]
|
||||
# c_val[0]["samples"] = (a_val[0]["samples"] * 0.5 * (1.0 - mixture)) + (b_val[0]["samples"] * 0.5 * (0.0 + mixture))
|
||||
# c_val[0]["samples"] = (a_val[0]["samples"] * (1.0 - mixture)) - (b_val[0]["samples"] * (0.0 + mixture))
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TestSampler": TestSampler
|
||||
}
|
||||
44
execution.py
44
execution.py
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import copy
|
||||
import json
|
||||
@ -11,6 +12,8 @@ import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
from message_queue import PromptExecutorMessageQueue
|
||||
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
@ -41,6 +44,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
# TODO: Called to execute Node's function
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
# check if node wants the lists
|
||||
input_is_list = False
|
||||
@ -72,6 +76,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
# TODO: Executes impl node or custom_nodes function
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
@ -117,6 +122,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
# TODO: Retrieves Node Input Data to be passed onto
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -152,6 +158,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
success, error = validate_output_ui_data(server, unique_id, prompt_id, class_type, executed, output_ui)
|
||||
if not success:
|
||||
raise Exception("Output UI Error: {}".format(error))
|
||||
if "UI_TEMPLATE" in output_ui:
|
||||
template_file = os.path.join(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)))), "custom_nodes", output_ui["UI_TEMPLATE"][0])
|
||||
with open(template_file, "r") as f:
|
||||
output_ui["UI_TEMPLATE"] = f.read()
|
||||
if len(output_ui["UI_TEMPLATE"]) <= 0: raise Exception("UI_TEMPLATE cannot be empty!")
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
@ -194,6 +208,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
|
||||
def validate_output_ui_data(server, node_id, prompt_id, class_type, executed, output_ui):
|
||||
try:
|
||||
json.dumps(output_ui)
|
||||
return True, None
|
||||
except Exception as error:
|
||||
return False, error
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -230,7 +252,9 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
except Exception as e:
|
||||
# TODO: IMPL Frontend UI
|
||||
print("Exception occured on IS_CHANGED: {}".format(e))
|
||||
to_delete = True
|
||||
else:
|
||||
is_changed = prompt[unique_id]['is_changed']
|
||||
@ -267,6 +291,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.outputs = {}
|
||||
# TODO: Caches node instances
|
||||
self.object_storage = {}
|
||||
self.outputs_ui = {}
|
||||
self.old_prompt = {}
|
||||
@ -382,6 +407,23 @@ class PromptExecutor:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
|
||||
def prompt_message_loop(self):
|
||||
# TODO: Better refactor, is it good here?
|
||||
try:
|
||||
while PromptExecutorMessageQueue.get_prompt_queue().not_empty:
|
||||
msg, data = PromptExecutorMessageQueue.get_prompt_queue().get(False)
|
||||
if msg:
|
||||
if msg == "NODE_REFRESH":
|
||||
for refreshed_node_list in data:
|
||||
for refreshed_node in refreshed_node_list:
|
||||
keys = self.object_storage.keys()
|
||||
for nodeKey in keys:
|
||||
if nodeKey[1] == refreshed_node["name"]:
|
||||
self.object_storage.pop(nodeKey)
|
||||
break
|
||||
print("PROMPT_EXECUTOR_MESSAGE_EVENT: {}".format(msg))
|
||||
except queue.Empty:
|
||||
pass # Just ignore
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
|
||||
20
main.bat
Normal file
20
main.bat
Normal file
@ -0,0 +1,20 @@
|
||||
@echo off
|
||||
|
||||
:: Deactivate the virtual environment
|
||||
call .\venv\Scripts\deactivate.bat
|
||||
|
||||
:: Activate the virtual environment
|
||||
call .\venv\Scripts\activate.bat
|
||||
set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
|
||||
|
||||
:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
|
||||
if %errorlevel% equ 0 (
|
||||
REM Check if the batch was started via double-click
|
||||
IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
|
||||
REM echo This script was started by double clicking.
|
||||
cmd /k python.exe main.py --auto-launch %*
|
||||
) ELSE (
|
||||
REM echo This script was started from a command prompt.
|
||||
python.exe main.py --auto-launch %*
|
||||
)
|
||||
)
|
||||
2
main.py
2
main.py
@ -86,11 +86,13 @@ def cuda_malloc_warning():
|
||||
if cuda_malloc_warning:
|
||||
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
# TODO: Prompt handler of each node recursively
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
execution_start_time = time.perf_counter()
|
||||
e.prompt_message_loop()
|
||||
prompt_id = item[1]
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
|
||||
11
message_queue.py
Normal file
11
message_queue.py
Normal file
@ -0,0 +1,11 @@
|
||||
import queue
|
||||
|
||||
|
||||
# This queue is loop-driven by second created thread that processes additional prompt messages
|
||||
|
||||
class PromptExecutorMessageQueue:
|
||||
__PROMPT_QUEUE = queue.LifoQueue()
|
||||
|
||||
@staticmethod
|
||||
def get_prompt_queue():
|
||||
return PromptExecutorMessageQueue.__PROMPT_QUEUE
|
||||
60
nodes.py
60
nodes.py
@ -14,6 +14,8 @@ from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
|
||||
from message_queue import PromptExecutorMessageQueue
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
|
||||
|
||||
@ -933,7 +935,7 @@ class LatentFromBatch:
|
||||
else:
|
||||
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
|
||||
return (s,)
|
||||
|
||||
|
||||
class RepeatLatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -948,7 +950,7 @@ class RepeatLatentBatch:
|
||||
def repeat(self, samples, amount):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
|
||||
|
||||
s["samples"] = s_in.repeat((amount, 1,1,1))
|
||||
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
|
||||
masks = samples["noise_mask"]
|
||||
@ -1277,7 +1279,7 @@ class SaveImage:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"})},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
@ -1707,8 +1709,22 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
EXTENSION_WEB_DIRS = {}
|
||||
|
||||
|
||||
class CustomNodeData:
|
||||
def __init__(self, name="", reloaded=False):
|
||||
self.name = name
|
||||
self.reloaded = reloaded
|
||||
|
||||
def dict(self):
|
||||
return self.__dict__
|
||||
|
||||
|
||||
# TODO: Validate custom node since it throws bad errors.
|
||||
def load_custom_node(module_path, ignore=set()):
|
||||
module_name = os.path.basename(module_path)
|
||||
module_reload = False
|
||||
loaded_custom_node_data = []
|
||||
|
||||
if os.path.isfile(module_path):
|
||||
sp = os.path.splitext(module_path)
|
||||
module_name = sp[0]
|
||||
@ -1720,6 +1736,10 @@ def load_custom_node(module_path, ignore=set()):
|
||||
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
|
||||
module_dir = module_path
|
||||
|
||||
if module_name in sys.modules:
|
||||
print("Module reload: {}".format(module_name))
|
||||
module_reload = True
|
||||
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
sys.modules[module_name] = module
|
||||
module_spec.loader.exec_module(module)
|
||||
@ -1731,14 +1751,18 @@ def load_custom_node(module_path, ignore=set()):
|
||||
|
||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||
for name in module.NODE_CLASS_MAPPINGS:
|
||||
if name not in ignore:
|
||||
if module_reload or name not in ignore:
|
||||
NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
|
||||
# TODO: Allow multiple params for node without overwriting
|
||||
loaded_custom_node_data.append(
|
||||
CustomNodeData(name, module_reload).dict()
|
||||
)
|
||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||
return True
|
||||
return True, loaded_custom_node_data
|
||||
else:
|
||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||
return False
|
||||
return False, loaded_custom_node_data
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||
@ -1748,29 +1772,39 @@ def load_custom_nodes():
|
||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
node_data = {}
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
if "__pycache__" in possible_modules:
|
||||
possible_modules.remove("__pycache__")
|
||||
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
||||
if os.path.basename(module_path).startswith("__") or os.path.splitext(module_path)[1] != ".py" or not os.path.isfile(module_path):
|
||||
print("Invalid module found: {}".format(possible_module))
|
||||
continue
|
||||
if module_path.endswith(".disabled"): continue
|
||||
time_before = time.perf_counter()
|
||||
success = load_custom_node(module_path, base_node_names)
|
||||
success, custom_node_data = load_custom_node(module_path, base_node_names)
|
||||
if success:
|
||||
node_data[module_path] = custom_node_data
|
||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
print("Custom Loaded Nodes Data: {}".format(node_data))
|
||||
|
||||
if len(node_import_times) > 0:
|
||||
print("\nImport times for custom nodes:")
|
||||
for n in sorted(node_import_times):
|
||||
if n[2]:
|
||||
import_message = ""
|
||||
import_message = " (SUCCESS)"
|
||||
else:
|
||||
import_message = " (IMPORT FAILED)"
|
||||
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||
print()
|
||||
|
||||
# Notify other prompt loop thread of refresh
|
||||
refreshed_nodes_list = [custom_node for custom_node in [custom_node_list for _, custom_node_list in node_data.items()]]
|
||||
PromptExecutorMessageQueue.get_prompt_queue().put(["NODE_REFRESH", refreshed_nodes_list])
|
||||
|
||||
return node_data
|
||||
|
||||
def init_custom_nodes():
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
@ -1781,4 +1815,6 @@ def init_custom_nodes():
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
|
||||
# TODO: How to load without pushing this complete addon
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes/comfyui_controlnet_aux"), "__init__.py"))
|
||||
load_custom_nodes()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchsde
|
||||
einops
|
||||
transformers>=4.25.1
|
||||
transformers==4.26.1
|
||||
safetensors>=0.3.0
|
||||
aiohttp
|
||||
accelerate
|
||||
@ -10,3 +10,4 @@ Pillow
|
||||
scipy
|
||||
tqdm
|
||||
psutil
|
||||
clip-interrogator==0.5.4
|
||||
@ -30,7 +30,6 @@ from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class BinaryEventTypes:
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
@ -424,6 +423,10 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/custom_nodes")
|
||||
async def get_load_nodes(request):
|
||||
return web.json_response(nodes.load_custom_nodes())
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
return web.json_response(self.prompt_queue.get_history())
|
||||
@ -446,7 +449,7 @@ class PromptServer():
|
||||
print("got prompt")
|
||||
resp_code = 200
|
||||
out_string = ""
|
||||
json_data = await request.json()
|
||||
json_data = await request.json()
|
||||
json_data = self.trigger_on_prompt(json_data)
|
||||
|
||||
if "number" in json_data:
|
||||
|
||||
@ -181,6 +181,11 @@ class ComfyApi extends EventTarget {
|
||||
return await resp.json();
|
||||
}
|
||||
|
||||
async getCustomNodes() {
|
||||
const resp = await this.fetchApi("/custom_nodes", { cache: "no-store" });
|
||||
return await resp.json()
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue
|
||||
|
||||
@ -1055,13 +1055,15 @@ export class ComfyApp {
|
||||
this.graph.setDirtyCanvas(true, false);
|
||||
delete this.nodePreviewImages[this.runningNodeId]
|
||||
});
|
||||
|
||||
// TODO: UI Update
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
|
||||
this.updateNode(node, detail);
|
||||
}
|
||||
});
|
||||
|
||||
@ -1181,8 +1183,8 @@ export class ComfyApp {
|
||||
this.loadGraphData();
|
||||
}
|
||||
|
||||
// Save current workflow automatically
|
||||
setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000);
|
||||
// Save current workflow automatically
|
||||
setInterval(async () => await this.saveWorkflow(), 1000);
|
||||
|
||||
this.#addDrawNodeHandler();
|
||||
this.#addDrawGroupsHandler();
|
||||
@ -1195,6 +1197,9 @@ export class ComfyApp {
|
||||
await this.#invokeExtensionsAsync("setup");
|
||||
}
|
||||
|
||||
async saveWorkflow() {
|
||||
localStorage.setItem("workflow", JSON.stringify(this.graph.serialize()));
|
||||
}
|
||||
/**
|
||||
* Registers nodes with the graph
|
||||
*/
|
||||
@ -1646,11 +1651,23 @@ export class ComfyApp {
|
||||
this.extensions.push(extension);
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh combo list on whole nodes
|
||||
*/
|
||||
async refreshComboInNodes() {
|
||||
/**
|
||||
* Refresh combo list on whole nodes
|
||||
* @param {HTMLElement} button
|
||||
*/
|
||||
async refreshComboInNodes(button) {
|
||||
if (button.getAttribute("disabled")) {
|
||||
// Do not allow multiple refreshes
|
||||
return;
|
||||
}
|
||||
button.setAttribute("disabled", true);
|
||||
// Reload custom node modules under custom_nodes
|
||||
const customNodeData = await api.getCustomNodes();
|
||||
|
||||
// Reload combobox
|
||||
const defs = await api.getNodeDefs();
|
||||
LiteGraph.clearRegisteredTypes();
|
||||
await this.registerNodesFromDefs(defs);
|
||||
|
||||
for(let nodeNum in this.graph._nodes) {
|
||||
const node = this.graph._nodes[nodeNum];
|
||||
@ -1674,6 +1691,8 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
button.removeAttribute("disabled");
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1686,6 +1705,25 @@ export class ComfyApp {
|
||||
this.lastExecutionError = null;
|
||||
this.runningNodeId = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update Node UI Based on node state data
|
||||
* TODO: Better Idea than just plain impl into App?
|
||||
*/
|
||||
updateNode(node, detail) {
|
||||
switch (node.type) {
|
||||
case "DebugNode":
|
||||
const {texts} = detail.output
|
||||
if (texts !== undefined && texts.length > 0) {
|
||||
node.title = texts[0].substring(0, 16);
|
||||
node.widgets[0].value = texts[0]
|
||||
}
|
||||
break;
|
||||
case "DebugCond":
|
||||
console.log(detail)
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const app = new ComfyApp();
|
||||
|
||||
@ -751,7 +751,7 @@ export class ComfyUI {
|
||||
$el("button", {
|
||||
id: "comfy-refresh-button",
|
||||
textContent: "Refresh",
|
||||
onclick: () => app.refreshComboInNodes()
|
||||
onclick: () => app.refreshComboInNodes(document.getElementById("comfy-refresh-button"))
|
||||
}),
|
||||
$el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}),
|
||||
$el("button", {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user