More in-progress changes for adding an event system and eventlistener, print node, and saving/loading and muxing latents

This commit is contained in:
InconsolableCellist 2023-04-28 15:20:23 -06:00
parent 3adb344fe3
commit e42746498d
4 changed files with 263 additions and 33 deletions

View File

@ -22,8 +22,8 @@ def load_torch_file(ckpt, safe_load=False):
return sd
def save_latent(samples, filename_prefix):
filename = os.path.join(folder_paths.get_output_directory(), (filename_prefix + "_latent.npy"))
def save_latent(samples, filename):
filename = os.path.join(folder_paths.get_output_directory(), filename)
np.save(filename, samples)

View File

@ -12,6 +12,25 @@ import nodes
import comfy.model_management
class EventDispatcher:
def __init__(self):
self.listeners = {}
def subscribe(self, event_name, callback):
if event_name not in self.listeners:
self.listeners[event_name] = []
self.listeners[event_name].append(callback)
def unsubscribe(self, event_name, callback):
if event_name in self.listeners:
self.listeners[event_name].remove(callback)
def emit(self, event_name, event_data, *args, **kwargs):
if event_name in self.listeners:
for callback in self.listeners[event_name]:
callback(event_name, event_data, *args, **kwargs)
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -40,7 +59,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = unique_id
return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, event_dispatcher):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@ -55,16 +74,21 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, event_dispatcher)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id }, server.client_id)
obj = class_def()
obj = class_def(event_dispatcher=event_dispatcher)
nodes.before_node_execution()
event_dispatcher.emit("node_started", {"event_type": "node_started", "unique_id": unique_id,
"class_type": prompt[unique_id]["class_type"]})
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
event_dispatcher.emit("node_finished", {"event_type": "node_started", "unique_id": unique_id,
"class_type": prompt[unique_id]["class_type"]})
if "ui" in outputs[unique_id]:
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
@ -142,10 +166,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
return to_delete
class PromptExecutor:
def __init__(self, server):
def __init__(self, server, event_dispatcher):
self.outputs = {}
self.old_prompt = {}
self.server = server
self.event_dispatcher = event_dispatcher
def execute(self, prompt, extra_data={}):
nodes.interrupt_processing(False)
@ -192,7 +217,8 @@ class PromptExecutor:
except:
valid = False
if valid:
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed,
self.event_dispatcher)
except Exception as e:
print(traceback.format_exc())
to_delete = []

View File

@ -21,15 +21,14 @@ if __name__ == "__main__":
import yaml
import execution
import folder_paths
import server
from nodes import init_custom_nodes
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
def prompt_worker(q, server, event_dispatcher):
e = execution.PromptExecutor(server, event_dispatcher)
while True:
item, item_id = q.get()
e.execute(item[-2], item[-1])
@ -93,7 +92,8 @@ if __name__ == "__main__":
server.add_routes()
hijack_progress(server)
threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start()
event_dispatcher = execution.EventDispatcher()
threading.Thread(target=prompt_worker, daemon=True, args=(q,server, event_dispatcher,)).start()
address = args.listen

248
nodes.py
View File

@ -36,6 +36,9 @@ def interrupt_processing(value=True):
MAX_RESOLUTION=8192
class CLIPTextEncode:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
@ -48,6 +51,8 @@ class CLIPTextEncode:
return ([[clip.encode(text), {}]], )
class ConditioningCombine:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@ -60,6 +65,8 @@ class ConditioningCombine:
return (conditioning_1 + conditioning_2, )
class ConditioningSetArea:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -86,8 +93,9 @@ class ConditioningSetArea:
return (c, )
class VAEDecode:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.device = device
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -101,7 +109,8 @@ class VAEDecode:
return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.event_dispatcher = event_dispatcher
self.device = device
@classmethod
@ -116,7 +125,8 @@ class VAEDecodeTiled:
return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.event_dispatcher = event_dispatcher
self.device = device
@classmethod
@ -138,8 +148,9 @@ class VAEEncode:
class VAEEncodeTiled:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.device = device
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -158,8 +169,9 @@ class VAEEncodeTiled:
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.device = device
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -192,6 +204,9 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class CheckpointLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ),
@ -207,6 +222,8 @@ class CheckpointLoader:
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@ -222,6 +239,8 @@ class CheckpointLoaderSimple:
return out
class DiffusersLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(cls):
paths = []
@ -246,6 +265,8 @@ class DiffusersLoader:
class unCLIPCheckpointLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
@ -261,6 +282,8 @@ class unCLIPCheckpointLoader:
return out
class CLIPSetLastLayer:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP", ),
@ -296,6 +319,8 @@ class LoraLoader:
return (model_lora, clip_lora)
class TomePatchModel:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
@ -312,6 +337,8 @@ class TomePatchModel:
return (m, )
class VAELoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
@ -327,6 +354,8 @@ class VAELoader:
return (vae,)
class ControlNetLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}}
@ -342,6 +371,8 @@ class ControlNetLoader:
return (controlnet,)
class DiffControlNetLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
@ -359,6 +390,8 @@ class DiffControlNetLoader:
class ControlNetApply:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -385,6 +418,8 @@ class ControlNetApply:
return (c, )
class CLIPLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
@ -400,6 +435,8 @@ class CLIPLoader:
return (clip,)
class CLIPVisionLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ),
@ -415,6 +452,8 @@ class CLIPVisionLoader:
return (clip_vision,)
class CLIPVisionEncode:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
@ -430,6 +469,8 @@ class CLIPVisionEncode:
return (output,)
class StyleModelLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}}
@ -446,6 +487,8 @@ class StyleModelLoader:
class StyleModelApply:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -466,6 +509,8 @@ class StyleModelApply:
return (c, )
class unCLIPConditioning:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
@ -492,6 +537,8 @@ class unCLIPConditioning:
return (c, )
class GLIGENLoader:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
@ -507,6 +554,8 @@ class GLIGENLoader:
return (gligen,)
class GLIGENTextBoxApply:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
@ -538,8 +587,9 @@ class GLIGENTextBoxApply:
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
def __init__(self, device="cpu", event_dispatcher=None):
self.device = device
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -557,6 +607,8 @@ class EmptyLatentImage:
class LatentFromBatch:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -576,6 +628,8 @@ class LatentFromBatch:
return (s,)
class LatentUpscale:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@ -596,21 +650,29 @@ class LatentUpscale:
return (s,)
class SaveLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"filename_prefix": ("STRING", {"default": "ComfyUI"})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "save"
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
CATEGORY = "latent"
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"filename": ("STRING", {"default": "ComfyUI_latent.npy"})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "save"
def save(self, samples, filename_prefix):
s = samples.copy()
comfy.utils.save_latent(samples["samples"], filename_prefix)
return (samples,)
CATEGORY = "latent"
def save(self, samples, filename):
s = samples.copy()
comfy.utils.save_latent(samples["samples"], filename)
@clas
return (samples,)
class LoadLatent:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "filename": ("STRING", {"default": "ComfyUI_latent.npy"})}}
@ -624,6 +686,37 @@ class LoadLatent:
derp = ({"samples": comfy.utils.load_latent(filename)},)
return derp
class MuxLatent:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"latent1": ("LATENT",),
"latent2": ("LATENT",),
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "interpolate"
CATEGORY = "latent"
def interpolate(self, latent1, latent2, weight):
# Ensure the latents have the same shape
if latent1["samples"].shape != latent2["samples"].shape:
raise ValueError("Latents must have the same shape")
# Interpolate the latents using the weight
interpolated_latent = latent1["samples"] * (1 - weight) + latent2["samples"] * weight
return ({"samples": interpolated_latent},)
class LatentRotate:
@classmethod
def INPUT_TYPES(s):
@ -649,6 +742,8 @@ class LatentRotate:
return (s,)
class LatentFlip:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -711,6 +806,8 @@ class LatentComposite:
return (samples_out,)
class LatentCrop:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -754,6 +851,8 @@ class LatentCrop:
return (s,)
class SetLatentNoiseMask:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
@ -791,6 +890,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
return (out, )
class KSampler:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required":
@ -815,6 +916,8 @@ class KSampler:
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
class KSamplerAdvanced:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required":
@ -848,9 +951,10 @@ class KSamplerAdvanced:
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
class SaveImage:
def __init__(self):
def __init__(self, event_dispatcher=None):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -924,8 +1028,9 @@ class SaveImage:
return { "ui": { "images": results } }
class PreviewImage(SaveImage):
def __init__(self):
def __init__(self, event_dispatcher=None):
self.output_dir = folder_paths.get_temp_directory()
self.event_dispatcher = event_dispatcher
self.type = "temp"
@classmethod
@ -936,6 +1041,8 @@ class PreviewImage(SaveImage):
}
class LoadImage:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
@ -976,6 +1083,8 @@ class LoadImage:
return True
class LoadImageMask:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod
def INPUT_TYPES(s):
@ -1024,6 +1133,8 @@ class LoadImageMask:
return True
class ImageScale:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@ -1045,6 +1156,8 @@ class ImageScale:
return (s,)
class ImageInvert:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -1061,6 +1174,8 @@ class ImageInvert:
class ImagePadForOutpaint:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
@ -1123,18 +1238,101 @@ class ImagePadForOutpaint:
return (new_image, mask)
class FrameCounter:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {"required": { "frame": ("INT", {"default": 0})}}
return {
"required": {
"frame": ("INT", {"default": 0}),
"fired": ("BOOL", {"default": False}),
},
}
@classmethod
def IS_CHANGED(cls, *args, **kwargs):
return True
RETURN_TYPES = ("INT",)
RETURN_TYPES = ("text",)
FUNCTION = "frame_counter"
CATEGORY = "operations"
def frame_counter(self, frame):
def frame_counter(self, frame, fired):
if fired:
frame += 1
return (frame,)
class EventListener:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"event_type": (["node_started", "node_finished"],),
"class_type": ("STRING", {"default": "KSampler"})
},
}
@classmethod
def IS_CHANGED(cls, *args, **kwargs):
return True
RETURN_TYPES = ("BOOL",)
RETURN_NAMES = ("fired",)
FUNCTION = "listen"
CATEGORY = "Events"
def listen(self, event_type, class_type):
self._fired = False
def event_listener(event, event_data):
print(f"Got an event of type {event_data['event_type']} with data {event_data}")
if (event_data["event_type"] == event_type and event_data["class_type"] == class_type):
self._fired = True
self.event_dispatcher.subscribe(event_type, event_listener)
return (self._fired,)
class PrinterNode:
def __init__(self, event_dispatcher):
self.event_dispatcher = event_dispatcher
@classmethod
def INPUT_TYPES(s):
return {
"required": {},
"optional": {
"text": ("text",),
"latent": ("LATENT",),
}
}
@classmethod
def IS_CHANGED(cls, *args, **kwargs):
return True
RETURN_TYPES = ()
FUNCTION = "print_value"
CATEGORY = "operations"
OUTPUT_NODE = True
def print_value(self, text=None, latent=None):
if latent is not None:
latent_hash = hashlib.sha256(latent["samples"].cpu().numpy().tobytes()).hexdigest()
print(f"Latent hash: {latent_hash}")
print(np.array2string(latent["samples"].cpu().numpy(), separator=', '))
print(text)
return {"ui": {"": text}}
NODE_CLASS_MAPPINGS = {
"KSampler": KSampler,
"CheckpointLoaderSimple": CheckpointLoaderSimple,
@ -1184,6 +1382,9 @@ NODE_CLASS_MAPPINGS = {
"CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader,
"FrameCounter": FrameCounter,
"PrinterNode": PrinterNode,
"EventListener": EventListener,
"MuxLatent": MuxLatent,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1236,6 +1437,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"VAEEncodeTiled": "VAE Encode (Tiled)",
# operations
"FrameCounter": "Frame Counter",
"PrinterNode": "Print",
"EventListener": "Event Listener",
"MuxLatent": "Mux Latent",
}
def load_custom_node(module_path):