mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-13 15:02:37 +08:00
feat: add image cropping
feat: remove custom js stuff since it just doesnt work well feat: add cond debugging feat: impls PoC refresh of custom_nodes + custom_node extensions feat: ignore workflows folder feat: add batch file to start application under windows feat: integrate reload custom node into refresh feat: update custom node ui feat: impl node change event handling !WIP! feat: add CustomNodeData class for reuse feat: remove all reloaded nodes for test purposes and save graph afterwards !WIP! feat: remove unused registeredNodes feat: comment out graph removal feat: comment on some functions for proper understanding and bookmarking (for now) feat: comment node execution location feat: add exception for IS_CHANGED issues feat: extend example_node README !WIP! feat: custom test nodes for now !WIP! feat: avoid refresh spam feat: add debug_cond custom_node with WIP ui feat: add hint for validating output_ui data feat: pass refresh button into combo function feat: impl output ui error feat: auto refresh nodes fix: various minor issues !WIP! feat: barebone JS scripting in BE for ui templating !WIP! feat: impl interrogation with clip feat: impl more debug samplers feat: change requirements.txt for transformers fix: __init__.py issues when importing custom_nodes feat: temp ignore 3rdparty code feat: add custom_nodes debug_latent and image_fx
This commit is contained in:
parent
0966d3ce82
commit
98155446bf
6
.gitignore
vendored
6
.gitignore
vendored
@ -6,6 +6,10 @@ __pycache__/
|
|||||||
/models/
|
/models/
|
||||||
/temp/
|
/temp/
|
||||||
/custom_nodes/
|
/custom_nodes/
|
||||||
|
cache
|
||||||
|
models/
|
||||||
|
temp/
|
||||||
|
/custom_nodes/.pytest_cache
|
||||||
!custom_nodes/example_node.py.example
|
!custom_nodes/example_node.py.example
|
||||||
extra_model_paths.yaml
|
extra_model_paths.yaml
|
||||||
/.vs
|
/.vs
|
||||||
@ -14,3 +18,5 @@ venv/
|
|||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
/workflows
|
||||||
|
**/comfyui_controlnet_aux
|
||||||
|
|||||||
58
custom_nodes/clip_interrogator.py
Normal file
58
custom_nodes/clip_interrogator.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from clip_interrogator import Interrogator, Config
|
||||||
|
from torch import Tensor
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
class ClipInterrogator:
|
||||||
|
MODEL_NAME = ["ViT-L-14/openai"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
"model_name": (ClipInterrogator.MODEL_NAME,),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "clip_interrogate"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
VALUE = ""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, image, clip, model_name):
|
||||||
|
# TODO: Why does this not cache immidiately
|
||||||
|
return hashlib.md5(str(bytearray(image.numpy())).encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
def clip_interrogate(self, image, clip, model_name):
|
||||||
|
img_tensor = image[0]
|
||||||
|
# define a transform to convert a tensor to PIL image
|
||||||
|
transform = T.ToPILImage()
|
||||||
|
h, w, c = img_tensor.size()
|
||||||
|
# print(h,w,c)
|
||||||
|
# convert the tensor to PIL image using above transform
|
||||||
|
img = transform(image[0].reshape(c, h, w)) # Reshape since Tensor is using Height, Width, Color but Image needs C, H, W
|
||||||
|
config = Config(clip_model_name=model_name)
|
||||||
|
config.apply_low_vram_defaults()
|
||||||
|
ci = Interrogator(config)
|
||||||
|
ClipInterrogator.VALUE = ci.interrogate(img)
|
||||||
|
print("Image:", ClipInterrogator.VALUE)
|
||||||
|
tokens = clip.tokenize(ClipInterrogator.VALUE)
|
||||||
|
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||||
|
return ([[cond, {"pooled_output": pooled}]], )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ClipInterrogator": ClipInterrogator
|
||||||
|
}
|
||||||
49
custom_nodes/debug_cond.py
Normal file
49
custom_nodes/debug_cond.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
class DebugCond:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
"cond_input": ("CONDITIONING",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "IMAGE",)
|
||||||
|
FUNCTION = "debug_node"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, clip, cond_input):
|
||||||
|
# TODO: Why does this not cache immidiately
|
||||||
|
return random.randint(0, 10000)
|
||||||
|
|
||||||
|
def debug_node(self, clip, cond_input):
|
||||||
|
# print("Cond Shape:", cond_input[0][0].shape)
|
||||||
|
# signal = cond_input[0][0].reshape(-1)
|
||||||
|
# stripped_signal = signal[::2048]
|
||||||
|
plt.plot(cond_input[0][0][0])
|
||||||
|
img = PIL.Image.frombytes('RGB', plt.gcf().canvas.get_width_height(), plt.gcf().canvas.tostring_rgb())
|
||||||
|
img_tensor = T.PILToTensor()(img) / 255.0
|
||||||
|
img_tensor = einops.reduce(img_tensor, "a b c -> 1 b c a", "max")
|
||||||
|
return cond_input, img_tensor
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DebugCond": DebugCond
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: Impl into execution.py
|
||||||
|
SCRIPT_TEMPLATE_PATH = os.path.join(os.path.join(__file__, os.pardir), "debug_cond.js")
|
||||||
34
custom_nodes/debug_latent.py
Normal file
34
custom_nodes/debug_latent.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLatent:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"latent": ("LATENT",), }
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", "LATENT",)
|
||||||
|
FUNCTION = "latent_space"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def latent_space(self, latent):
|
||||||
|
x = latent["samples"]
|
||||||
|
transformer = T.ToPILImage()
|
||||||
|
img: Image = transformer(x[0])
|
||||||
|
# img.show()
|
||||||
|
# y = x * 0.75 - x * 0.25 + torch.rand(x.shape) * 0.1
|
||||||
|
y = x * 0.5 + torch.rand(x.shape) * 0.5
|
||||||
|
modified_latent = {"samples": y}
|
||||||
|
return (latent, modified_latent)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DebugLatent": DebugLatent
|
||||||
|
}
|
||||||
28
custom_nodes/debug_model.py
Normal file
28
custom_nodes/debug_model.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class DebugModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model_input": ("MODEL",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "debug_node"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def debug_node(self, model_input):
|
||||||
|
print("Model:", model_input)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DebugModel": DebugModel
|
||||||
|
}
|
||||||
23
custom_nodes/debug_node.py
Normal file
23
custom_nodes/debug_node.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
class DebugNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"cond_input": ("CONDITIONING",),
|
||||||
|
"text": ("STRING", { "default": "" }),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "debug_node"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def debug_node(self, cond_input, text):
|
||||||
|
return { "ui": { "texts": ["ABC"] } }
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"DebugNode": DebugNode
|
||||||
|
}
|
||||||
@ -6,6 +6,8 @@ class Example:
|
|||||||
-------------
|
-------------
|
||||||
INPUT_TYPES (dict):
|
INPUT_TYPES (dict):
|
||||||
Tell the main program input parameters of nodes.
|
Tell the main program input parameters of nodes.
|
||||||
|
IS_CHANGED (dict) -> str:
|
||||||
|
Tells the prompt loop if the current node has change on new execution based on a string identifier
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
@ -37,7 +39,8 @@ class Example:
|
|||||||
The type can be a list for selection.
|
The type can be a list for selection.
|
||||||
|
|
||||||
Returns: `dict`:
|
Returns: `dict`:
|
||||||
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
|
- Key input_fields_group (`string`): Can be either required, hidden or optional.
|
||||||
|
- A node class must have property `required`
|
||||||
- Value input_fields (`dict`): Contains input fields config:
|
- Value input_fields (`dict`): Contains input fields config:
|
||||||
* Key field_name (`string`): Name of a entry-point method's argument
|
* Key field_name (`string`): Name of a entry-point method's argument
|
||||||
* Value field_config (`tuple`):
|
* Value field_config (`tuple`):
|
||||||
|
|||||||
63
custom_nodes/image_crop.py
Normal file
63
custom_nodes/image_crop.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import ImageFilter
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCrop:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"latent": ("LATENT",),
|
||||||
|
"center_x": ("INT", {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0, # Minimum value
|
||||||
|
"max": 4096, # Maximum value
|
||||||
|
"step": 16, # Slider's step
|
||||||
|
}),
|
||||||
|
"center_y": ("INT", {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0, # Minimum value
|
||||||
|
"max": 4096, # Maximum value
|
||||||
|
"step": 16, # Slider's step
|
||||||
|
}),
|
||||||
|
"pixelradius": ("INT", {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0, # Minimum value
|
||||||
|
"max": 4096, # Maximum value
|
||||||
|
"step": 16, # Slider's step
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", "IMAGE",)
|
||||||
|
|
||||||
|
FUNCTION = "image_crop"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def image_crop(self, vae, latent, center_x, center_y, pixelradius):
|
||||||
|
tensor_img = vae.decode(latent["samples"])
|
||||||
|
stripped_tensor_img = tensor_img[0]
|
||||||
|
h, w, c = stripped_tensor_img.size()
|
||||||
|
pil_img: Image = T.ToPILImage()(einops.rearrange(stripped_tensor_img, "h w c -> c h w"))
|
||||||
|
nw, nh = center_x + pixelradius / 2, center_y + pixelradius / 2
|
||||||
|
pil_img = pil_img.crop((center_x - pixelradius / 2, center_y - pixelradius / 2, nw, nh))
|
||||||
|
new_tensor_img = einops.reduce(T.ToTensor()(pil_img), "c h w -> 1 h w c", "max")
|
||||||
|
# new_tensor_img = new_stripped_tensor_img.permute(0, 1, 2, 3)
|
||||||
|
pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img)
|
||||||
|
new_latent = vae.encode(pixels[:, :, :, :3])
|
||||||
|
return ({"samples": new_latent}, new_tensor_img)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImageCrop": ImageCrop
|
||||||
|
}
|
||||||
43
custom_nodes/image_fx.py
Normal file
43
custom_nodes/image_fx.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import ImageFilter
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFX:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"latent": ("LATENT",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", "IMAGE",)
|
||||||
|
|
||||||
|
FUNCTION = "image_fx"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def image_fx(self, vae, latent):
|
||||||
|
tensor_img = vae.decode(latent["samples"])
|
||||||
|
stripped_tensor_img = tensor_img[0]
|
||||||
|
h, w, c = stripped_tensor_img.size()
|
||||||
|
pil_img: Image = T.ToPILImage()(stripped_tensor_img.reshape(c, h, w))
|
||||||
|
pil_img = pil_img.filter(ImageFilter.ModeFilter(2))
|
||||||
|
new_stripped_tensor_img = T.PILToTensor()(pil_img) / 255.0
|
||||||
|
new_tensor_img = new_stripped_tensor_img.reshape(1, h, w, c)
|
||||||
|
pixels = nodes.VAEEncode.vae_encode_crop_pixels(new_tensor_img)
|
||||||
|
new_latent = vae.encode(pixels[:, :, :, :3])
|
||||||
|
return ({"samples": new_latent}, new_tensor_img)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImageFX": ImageFX
|
||||||
|
}
|
||||||
42
custom_nodes/test_generator.py
Normal file
42
custom_nodes/test_generator.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerator:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.testID = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"clip": ("CLIP",),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"testId": ("STRING",),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "test_generator"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
TESTID = 0
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(s, clip, testId=None):
|
||||||
|
# intValue = random.randint(0, 100)
|
||||||
|
# value = str(intValue)
|
||||||
|
if TestGenerator.TESTID < 2:
|
||||||
|
TestGenerator.TESTID += 1
|
||||||
|
return str(TestGenerator.TESTID)
|
||||||
|
|
||||||
|
def test_generator(self, clip, testId=None):
|
||||||
|
tokens = clip.tokenize("test")
|
||||||
|
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||||
|
return ([[cond, {"pooled_output": pooled}]], )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TestGenerator": TestGenerator
|
||||||
|
}
|
||||||
48
custom_nodes/test_node.py
Normal file
48
custom_nodes/test_node.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from transformers.models import clip
|
||||||
|
|
||||||
|
|
||||||
|
class TestNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"int_field": ("INT", {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0, #Minimum value
|
||||||
|
"max": 4096, #Maximum value
|
||||||
|
"step": 64, #Slider's step
|
||||||
|
"display": "number" # Cosmetic only: display as "number" or "slider"
|
||||||
|
}),
|
||||||
|
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}),
|
||||||
|
"print_to_screen": (["enable", "disable"],),
|
||||||
|
"string_field": ("STRING", {
|
||||||
|
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||||
|
"default": "dong!"
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "test"
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def test(self, clip, image, string_field, int_field, float_field, print_to_screen):
|
||||||
|
if print_to_screen == "enable":
|
||||||
|
print(f"""Your input contains:
|
||||||
|
string_field aka input text: {string_field}
|
||||||
|
int_field: {int_field}
|
||||||
|
float_field: {float_field}
|
||||||
|
""")
|
||||||
|
#do some processing on the image, in this example I just invert it
|
||||||
|
image = 0.5 - image
|
||||||
|
tokens = clip.tokenize("test")
|
||||||
|
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||||
|
return ([[cond, {"pooled_output": pooled}]], )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TestNode": TestNode,
|
||||||
|
"TestNode2": TestNode,
|
||||||
|
}
|
||||||
50
custom_nodes/test_sampler.py
Normal file
50
custom_nodes/test_sampler.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import latent_preview
|
||||||
|
from custom_nodes.debug_model import DebugModel
|
||||||
|
from nodes import common_ksampler
|
||||||
|
|
||||||
|
|
||||||
|
class TestSampler:
|
||||||
|
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||||
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc",
|
||||||
|
"uni_pc_bh2"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model": ("MODEL",),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
||||||
|
"sampler_name": (TestSampler.SAMPLERS,),
|
||||||
|
"scheduler": (TestSampler.SCHEDULERS,),
|
||||||
|
"positive": ("CONDITIONING",),
|
||||||
|
"negative": ("CONDITIONING",),
|
||||||
|
"latent_image": ("LATENT",),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"mixture": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT", "LATENT", "LATENT")
|
||||||
|
FUNCTION = "sample"
|
||||||
|
|
||||||
|
CATEGORY = "inflamously"
|
||||||
|
|
||||||
|
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0,
|
||||||
|
mixture=1.0):
|
||||||
|
a_val = common_ksampler(model, seed, round(steps / 2), cfg, sampler_name, scheduler, positive, negative,
|
||||||
|
latent_image, denoise=.8)
|
||||||
|
b_val = common_ksampler(model, seed + 1, round(steps / 2), cfg, sampler_name, scheduler, positive, negative,
|
||||||
|
a_val[0], denoise=.9)
|
||||||
|
x_val = common_ksampler(model, seed + 2, round(steps), cfg, sampler_name, scheduler, positive, negative, b_val[0], denoise=denoise)
|
||||||
|
return (x_val[0], a_val[0], b_val[0])
|
||||||
|
|
||||||
|
# c_val = [{"samples": None}]
|
||||||
|
# c_val[0]["samples"] = (a_val[0]["samples"] * 0.5 * (1.0 - mixture)) + (b_val[0]["samples"] * 0.5 * (0.0 + mixture))
|
||||||
|
# c_val[0]["samples"] = (a_val[0]["samples"] * (1.0 - mixture)) - (b_val[0]["samples"] * (0.0 + mixture))
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TestSampler": TestSampler
|
||||||
|
}
|
||||||
44
execution.py
44
execution.py
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
@ -11,6 +12,8 @@ import torch
|
|||||||
import nodes
|
import nodes
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from message_queue import PromptExecutorMessageQueue
|
||||||
|
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
@ -41,6 +44,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|
||||||
|
# TODO: Called to execute Node's function
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = False
|
input_is_list = False
|
||||||
@ -72,6 +76,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
|||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
if allow_interrupt:
|
if allow_interrupt:
|
||||||
nodes.before_node_execution()
|
nodes.before_node_execution()
|
||||||
|
# TODO: Executes impl node or custom_nodes function
|
||||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -117,6 +122,7 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
|
# TODO: Retrieves Node Input Data to be passed onto
|
||||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
@ -152,6 +158,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||||
outputs[unique_id] = output_data
|
outputs[unique_id] = output_data
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
|
success, error = validate_output_ui_data(server, unique_id, prompt_id, class_type, executed, output_ui)
|
||||||
|
if not success:
|
||||||
|
raise Exception("Output UI Error: {}".format(error))
|
||||||
|
if "UI_TEMPLATE" in output_ui:
|
||||||
|
template_file = os.path.join(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)))), "custom_nodes", output_ui["UI_TEMPLATE"][0])
|
||||||
|
with open(template_file, "r") as f:
|
||||||
|
output_ui["UI_TEMPLATE"] = f.read()
|
||||||
|
if len(output_ui["UI_TEMPLATE"]) <= 0: raise Exception("UI_TEMPLATE cannot be empty!")
|
||||||
outputs_ui[unique_id] = output_ui
|
outputs_ui[unique_id] = output_ui
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
@ -194,6 +208,14 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
|
|
||||||
return (True, None, None)
|
return (True, None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_output_ui_data(server, node_id, prompt_id, class_type, executed, output_ui):
|
||||||
|
try:
|
||||||
|
json.dumps(output_ui)
|
||||||
|
return True, None
|
||||||
|
except Exception as error:
|
||||||
|
return False, error
|
||||||
|
|
||||||
def recursive_will_execute(prompt, outputs, current_item):
|
def recursive_will_execute(prompt, outputs, current_item):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
@ -230,7 +252,9 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
except:
|
except Exception as e:
|
||||||
|
# TODO: IMPL Frontend UI
|
||||||
|
print("Exception occured on IS_CHANGED: {}".format(e))
|
||||||
to_delete = True
|
to_delete = True
|
||||||
else:
|
else:
|
||||||
is_changed = prompt[unique_id]['is_changed']
|
is_changed = prompt[unique_id]['is_changed']
|
||||||
@ -267,6 +291,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.outputs = {}
|
self.outputs = {}
|
||||||
|
# TODO: Caches node instances
|
||||||
self.object_storage = {}
|
self.object_storage = {}
|
||||||
self.outputs_ui = {}
|
self.outputs_ui = {}
|
||||||
self.old_prompt = {}
|
self.old_prompt = {}
|
||||||
@ -382,6 +407,23 @@ class PromptExecutor:
|
|||||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
|
|
||||||
|
def prompt_message_loop(self):
|
||||||
|
# TODO: Better refactor, is it good here?
|
||||||
|
try:
|
||||||
|
while PromptExecutorMessageQueue.get_prompt_queue().not_empty:
|
||||||
|
msg, data = PromptExecutorMessageQueue.get_prompt_queue().get(False)
|
||||||
|
if msg:
|
||||||
|
if msg == "NODE_REFRESH":
|
||||||
|
for refreshed_node_list in data:
|
||||||
|
for refreshed_node in refreshed_node_list:
|
||||||
|
keys = self.object_storage.keys()
|
||||||
|
for nodeKey in keys:
|
||||||
|
if nodeKey[1] == refreshed_node["name"]:
|
||||||
|
self.object_storage.pop(nodeKey)
|
||||||
|
break
|
||||||
|
print("PROMPT_EXECUTOR_MESSAGE_EVENT: {}".format(msg))
|
||||||
|
except queue.Empty:
|
||||||
|
pass # Just ignore
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
def validate_inputs(prompt, item, validated):
|
||||||
|
|||||||
20
main.bat
Normal file
20
main.bat
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
@echo off
|
||||||
|
|
||||||
|
:: Deactivate the virtual environment
|
||||||
|
call .\venv\Scripts\deactivate.bat
|
||||||
|
|
||||||
|
:: Activate the virtual environment
|
||||||
|
call .\venv\Scripts\activate.bat
|
||||||
|
set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib
|
||||||
|
|
||||||
|
:: If the exit code is 0, run the kohya_gui.py script with the command-line arguments
|
||||||
|
if %errorlevel% equ 0 (
|
||||||
|
REM Check if the batch was started via double-click
|
||||||
|
IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
|
||||||
|
REM echo This script was started by double clicking.
|
||||||
|
cmd /k python.exe main.py --auto-launch %*
|
||||||
|
) ELSE (
|
||||||
|
REM echo This script was started from a command prompt.
|
||||||
|
python.exe main.py --auto-launch %*
|
||||||
|
)
|
||||||
|
)
|
||||||
2
main.py
2
main.py
@ -86,11 +86,13 @@ def cuda_malloc_warning():
|
|||||||
if cuda_malloc_warning:
|
if cuda_malloc_warning:
|
||||||
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
# TODO: Prompt handler of each node recursively
|
||||||
def prompt_worker(q, server):
|
def prompt_worker(q, server):
|
||||||
e = execution.PromptExecutor(server)
|
e = execution.PromptExecutor(server)
|
||||||
while True:
|
while True:
|
||||||
item, item_id = q.get()
|
item, item_id = q.get()
|
||||||
execution_start_time = time.perf_counter()
|
execution_start_time = time.perf_counter()
|
||||||
|
e.prompt_message_loop()
|
||||||
prompt_id = item[1]
|
prompt_id = item[1]
|
||||||
e.execute(item[2], prompt_id, item[3], item[4])
|
e.execute(item[2], prompt_id, item[3], item[4])
|
||||||
q.task_done(item_id, e.outputs_ui)
|
q.task_done(item_id, e.outputs_ui)
|
||||||
|
|||||||
11
message_queue.py
Normal file
11
message_queue.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import queue
|
||||||
|
|
||||||
|
|
||||||
|
# This queue is loop-driven by second created thread that processes additional prompt messages
|
||||||
|
|
||||||
|
class PromptExecutorMessageQueue:
|
||||||
|
__PROMPT_QUEUE = queue.LifoQueue()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_prompt_queue():
|
||||||
|
return PromptExecutorMessageQueue.__PROMPT_QUEUE
|
||||||
54
nodes.py
54
nodes.py
@ -14,6 +14,8 @@ from PIL.PngImagePlugin import PngInfo
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
from message_queue import PromptExecutorMessageQueue
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||||
|
|
||||||
|
|
||||||
@ -1707,8 +1709,22 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
|
|
||||||
EXTENSION_WEB_DIRS = {}
|
EXTENSION_WEB_DIRS = {}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomNodeData:
|
||||||
|
def __init__(self, name="", reloaded=False):
|
||||||
|
self.name = name
|
||||||
|
self.reloaded = reloaded
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
return self.__dict__
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Validate custom node since it throws bad errors.
|
||||||
def load_custom_node(module_path, ignore=set()):
|
def load_custom_node(module_path, ignore=set()):
|
||||||
module_name = os.path.basename(module_path)
|
module_name = os.path.basename(module_path)
|
||||||
|
module_reload = False
|
||||||
|
loaded_custom_node_data = []
|
||||||
|
|
||||||
if os.path.isfile(module_path):
|
if os.path.isfile(module_path):
|
||||||
sp = os.path.splitext(module_path)
|
sp = os.path.splitext(module_path)
|
||||||
module_name = sp[0]
|
module_name = sp[0]
|
||||||
@ -1720,6 +1736,10 @@ def load_custom_node(module_path, ignore=set()):
|
|||||||
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
|
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
|
||||||
module_dir = module_path
|
module_dir = module_path
|
||||||
|
|
||||||
|
if module_name in sys.modules:
|
||||||
|
print("Module reload: {}".format(module_name))
|
||||||
|
module_reload = True
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(module_spec)
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
sys.modules[module_name] = module
|
sys.modules[module_name] = module
|
||||||
module_spec.loader.exec_module(module)
|
module_spec.loader.exec_module(module)
|
||||||
@ -1731,14 +1751,18 @@ def load_custom_node(module_path, ignore=set()):
|
|||||||
|
|
||||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||||
for name in module.NODE_CLASS_MAPPINGS:
|
for name in module.NODE_CLASS_MAPPINGS:
|
||||||
if name not in ignore:
|
if module_reload or name not in ignore:
|
||||||
NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
|
NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
|
||||||
|
# TODO: Allow multiple params for node without overwriting
|
||||||
|
loaded_custom_node_data.append(
|
||||||
|
CustomNodeData(name, module_reload).dict()
|
||||||
|
)
|
||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
return True
|
return True, loaded_custom_node_data
|
||||||
else:
|
else:
|
||||||
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||||
return False
|
return False, loaded_custom_node_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print(f"Cannot import {module_path} module for custom nodes:", e)
|
print(f"Cannot import {module_path} module for custom nodes:", e)
|
||||||
@ -1748,29 +1772,39 @@ def load_custom_nodes():
|
|||||||
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
|
||||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||||
node_import_times = []
|
node_import_times = []
|
||||||
|
node_data = {}
|
||||||
for custom_node_path in node_paths:
|
for custom_node_path in node_paths:
|
||||||
possible_modules = os.listdir(custom_node_path)
|
possible_modules = os.listdir(custom_node_path)
|
||||||
if "__pycache__" in possible_modules:
|
|
||||||
possible_modules.remove("__pycache__")
|
|
||||||
|
|
||||||
for possible_module in possible_modules:
|
for possible_module in possible_modules:
|
||||||
module_path = os.path.join(custom_node_path, possible_module)
|
module_path = os.path.join(custom_node_path, possible_module)
|
||||||
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
|
if os.path.basename(module_path).startswith("__") or os.path.splitext(module_path)[1] != ".py" or not os.path.isfile(module_path):
|
||||||
|
print("Invalid module found: {}".format(possible_module))
|
||||||
|
continue
|
||||||
if module_path.endswith(".disabled"): continue
|
if module_path.endswith(".disabled"): continue
|
||||||
time_before = time.perf_counter()
|
time_before = time.perf_counter()
|
||||||
success = load_custom_node(module_path, base_node_names)
|
success, custom_node_data = load_custom_node(module_path, base_node_names)
|
||||||
|
if success:
|
||||||
|
node_data[module_path] = custom_node_data
|
||||||
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
node_import_times.append((time.perf_counter() - time_before, module_path, success))
|
||||||
|
|
||||||
|
print("Custom Loaded Nodes Data: {}".format(node_data))
|
||||||
|
|
||||||
if len(node_import_times) > 0:
|
if len(node_import_times) > 0:
|
||||||
print("\nImport times for custom nodes:")
|
print("\nImport times for custom nodes:")
|
||||||
for n in sorted(node_import_times):
|
for n in sorted(node_import_times):
|
||||||
if n[2]:
|
if n[2]:
|
||||||
import_message = ""
|
import_message = " (SUCCESS)"
|
||||||
else:
|
else:
|
||||||
import_message = " (IMPORT FAILED)"
|
import_message = " (IMPORT FAILED)"
|
||||||
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
# Notify other prompt loop thread of refresh
|
||||||
|
refreshed_nodes_list = [custom_node for custom_node in [custom_node_list for _, custom_node_list in node_data.items()]]
|
||||||
|
PromptExecutorMessageQueue.get_prompt_queue().put(["NODE_REFRESH", refreshed_nodes_list])
|
||||||
|
|
||||||
|
return node_data
|
||||||
|
|
||||||
def init_custom_nodes():
|
def init_custom_nodes():
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||||
@ -1781,4 +1815,6 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
|
||||||
|
# TODO: How to load without pushing this complete addon
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes/comfyui_controlnet_aux"), "__init__.py"))
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
einops
|
einops
|
||||||
transformers>=4.25.1
|
transformers==4.26.1
|
||||||
safetensors>=0.3.0
|
safetensors>=0.3.0
|
||||||
aiohttp
|
aiohttp
|
||||||
accelerate
|
accelerate
|
||||||
@ -10,3 +10,4 @@ Pillow
|
|||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
psutil
|
psutil
|
||||||
|
clip-interrogator==0.5.4
|
||||||
@ -30,7 +30,6 @@ from comfy.cli_args import args
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes:
|
class BinaryEventTypes:
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
@ -424,6 +423,10 @@ class PromptServer():
|
|||||||
out[node_class] = node_info(node_class)
|
out[node_class] = node_info(node_class)
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
|
@routes.get("/custom_nodes")
|
||||||
|
async def get_load_nodes(request):
|
||||||
|
return web.json_response(nodes.load_custom_nodes())
|
||||||
|
|
||||||
@routes.get("/history")
|
@routes.get("/history")
|
||||||
async def get_history(request):
|
async def get_history(request):
|
||||||
return web.json_response(self.prompt_queue.get_history())
|
return web.json_response(self.prompt_queue.get_history())
|
||||||
@ -446,7 +449,7 @@ class PromptServer():
|
|||||||
print("got prompt")
|
print("got prompt")
|
||||||
resp_code = 200
|
resp_code = 200
|
||||||
out_string = ""
|
out_string = ""
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
json_data = self.trigger_on_prompt(json_data)
|
json_data = self.trigger_on_prompt(json_data)
|
||||||
|
|
||||||
if "number" in json_data:
|
if "number" in json_data:
|
||||||
|
|||||||
@ -181,6 +181,11 @@ class ComfyApi extends EventTarget {
|
|||||||
return await resp.json();
|
return await resp.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getCustomNodes() {
|
||||||
|
const resp = await this.fetchApi("/custom_nodes", { cache: "no-store" });
|
||||||
|
return await resp.json()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue
|
* @param {number} number The index at which to queue the prompt, passing -1 will insert the prompt at the front of the queue
|
||||||
|
|||||||
@ -1055,13 +1055,15 @@ export class ComfyApp {
|
|||||||
this.graph.setDirtyCanvas(true, false);
|
this.graph.setDirtyCanvas(true, false);
|
||||||
delete this.nodePreviewImages[this.runningNodeId]
|
delete this.nodePreviewImages[this.runningNodeId]
|
||||||
});
|
});
|
||||||
|
// TODO: UI Update
|
||||||
api.addEventListener("executed", ({ detail }) => {
|
api.addEventListener("executed", ({ detail }) => {
|
||||||
this.nodeOutputs[detail.node] = detail.output;
|
this.nodeOutputs[detail.node] = detail.output;
|
||||||
const node = this.graph.getNodeById(detail.node);
|
const node = this.graph.getNodeById(detail.node);
|
||||||
if (node) {
|
if (node) {
|
||||||
if (node.onExecuted)
|
if (node.onExecuted)
|
||||||
node.onExecuted(detail.output);
|
node.onExecuted(detail.output);
|
||||||
|
|
||||||
|
this.updateNode(node, detail);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -1181,8 +1183,8 @@ export class ComfyApp {
|
|||||||
this.loadGraphData();
|
this.loadGraphData();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save current workflow automatically
|
// Save current workflow automatically
|
||||||
setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000);
|
setInterval(async () => await this.saveWorkflow(), 1000);
|
||||||
|
|
||||||
this.#addDrawNodeHandler();
|
this.#addDrawNodeHandler();
|
||||||
this.#addDrawGroupsHandler();
|
this.#addDrawGroupsHandler();
|
||||||
@ -1195,6 +1197,9 @@ export class ComfyApp {
|
|||||||
await this.#invokeExtensionsAsync("setup");
|
await this.#invokeExtensionsAsync("setup");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async saveWorkflow() {
|
||||||
|
localStorage.setItem("workflow", JSON.stringify(this.graph.serialize()));
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* Registers nodes with the graph
|
* Registers nodes with the graph
|
||||||
*/
|
*/
|
||||||
@ -1646,11 +1651,23 @@ export class ComfyApp {
|
|||||||
this.extensions.push(extension);
|
this.extensions.push(extension);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Refresh combo list on whole nodes
|
* Refresh combo list on whole nodes
|
||||||
*/
|
* @param {HTMLElement} button
|
||||||
async refreshComboInNodes() {
|
*/
|
||||||
|
async refreshComboInNodes(button) {
|
||||||
|
if (button.getAttribute("disabled")) {
|
||||||
|
// Do not allow multiple refreshes
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
button.setAttribute("disabled", true);
|
||||||
|
// Reload custom node modules under custom_nodes
|
||||||
|
const customNodeData = await api.getCustomNodes();
|
||||||
|
|
||||||
|
// Reload combobox
|
||||||
const defs = await api.getNodeDefs();
|
const defs = await api.getNodeDefs();
|
||||||
|
LiteGraph.clearRegisteredTypes();
|
||||||
|
await this.registerNodesFromDefs(defs);
|
||||||
|
|
||||||
for(let nodeNum in this.graph._nodes) {
|
for(let nodeNum in this.graph._nodes) {
|
||||||
const node = this.graph._nodes[nodeNum];
|
const node = this.graph._nodes[nodeNum];
|
||||||
@ -1674,6 +1691,8 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
button.removeAttribute("disabled");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1686,6 +1705,25 @@ export class ComfyApp {
|
|||||||
this.lastExecutionError = null;
|
this.lastExecutionError = null;
|
||||||
this.runningNodeId = null;
|
this.runningNodeId = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update Node UI Based on node state data
|
||||||
|
* TODO: Better Idea than just plain impl into App?
|
||||||
|
*/
|
||||||
|
updateNode(node, detail) {
|
||||||
|
switch (node.type) {
|
||||||
|
case "DebugNode":
|
||||||
|
const {texts} = detail.output
|
||||||
|
if (texts !== undefined && texts.length > 0) {
|
||||||
|
node.title = texts[0].substring(0, 16);
|
||||||
|
node.widgets[0].value = texts[0]
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "DebugCond":
|
||||||
|
console.log(detail)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export const app = new ComfyApp();
|
export const app = new ComfyApp();
|
||||||
|
|||||||
@ -751,7 +751,7 @@ export class ComfyUI {
|
|||||||
$el("button", {
|
$el("button", {
|
||||||
id: "comfy-refresh-button",
|
id: "comfy-refresh-button",
|
||||||
textContent: "Refresh",
|
textContent: "Refresh",
|
||||||
onclick: () => app.refreshComboInNodes()
|
onclick: () => app.refreshComboInNodes(document.getElementById("comfy-refresh-button"))
|
||||||
}),
|
}),
|
||||||
$el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}),
|
$el("button", {id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace()}),
|
||||||
$el("button", {
|
$el("button", {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user