From e42746498d98fa723f6ada655bfcb8ed87e04bd1 Mon Sep 17 00:00:00 2001 From: InconsolableCellist <23345188+InconsolableCellist@users.noreply.github.com> Date: Fri, 28 Apr 2023 15:20:23 -0600 Subject: [PATCH] More in-progress changes for adding an event system and eventlistener, print node, and saving/loading and muxing latents --- comfy/utils.py | 4 +- execution.py | 36 ++++++- main.py | 8 +- nodes.py | 248 ++++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 263 insertions(+), 33 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 0e768b0d7..46bc325c6 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -22,8 +22,8 @@ def load_torch_file(ckpt, safe_load=False): return sd -def save_latent(samples, filename_prefix): - filename = os.path.join(folder_paths.get_output_directory(), (filename_prefix + "_latent.npy")) +def save_latent(samples, filename): + filename = os.path.join(folder_paths.get_output_directory(), filename) np.save(filename, samples) diff --git a/execution.py b/execution.py index c19c10bc6..f624490ef 100644 --- a/execution.py +++ b/execution.py @@ -12,6 +12,25 @@ import nodes import comfy.model_management + +class EventDispatcher: + def __init__(self): + self.listeners = {} + + def subscribe(self, event_name, callback): + if event_name not in self.listeners: + self.listeners[event_name] = [] + self.listeners[event_name].append(callback) + + def unsubscribe(self, event_name, callback): + if event_name in self.listeners: + self.listeners[event_name].remove(callback) + + def emit(self, event_name, event_data, *args, **kwargs): + if event_name in self.listeners: + for callback in self.listeners[event_name]: + callback(event_name, event_data, *args, **kwargs) + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -40,7 +59,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = unique_id return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, event_dispatcher): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -55,16 +74,21 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, event_dispatcher) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id server.send_sync("executing", { "node": unique_id }, server.client_id) - obj = class_def() + obj = class_def(event_dispatcher=event_dispatcher) nodes.before_node_execution() + event_dispatcher.emit("node_started", {"event_type": "node_started", "unique_id": unique_id, + "class_type": prompt[unique_id]["class_type"]}) outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) + event_dispatcher.emit("node_finished", {"event_type": "node_started", "unique_id": unique_id, + "class_type": prompt[unique_id]["class_type"]}) + if "ui" in outputs[unique_id]: if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) @@ -142,10 +166,11 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item return to_delete class PromptExecutor: - def __init__(self, server): + def __init__(self, server, event_dispatcher): self.outputs = {} self.old_prompt = {} self.server = server + self.event_dispatcher = event_dispatcher def execute(self, prompt, extra_data={}): nodes.interrupt_processing(False) @@ -192,7 +217,8 @@ class PromptExecutor: except: valid = False if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, + self.event_dispatcher) except Exception as e: print(traceback.format_exc()) to_delete = [] diff --git a/main.py b/main.py index 02c700ebc..880d0901c 100644 --- a/main.py +++ b/main.py @@ -21,15 +21,14 @@ if __name__ == "__main__": import yaml - import execution import folder_paths import server from nodes import init_custom_nodes -def prompt_worker(q, server): - e = execution.PromptExecutor(server) +def prompt_worker(q, server, event_dispatcher): + e = execution.PromptExecutor(server, event_dispatcher) while True: item, item_id = q.get() e.execute(item[-2], item[-1]) @@ -93,7 +92,8 @@ if __name__ == "__main__": server.add_routes() hijack_progress(server) - threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() + event_dispatcher = execution.EventDispatcher() + threading.Thread(target=prompt_worker, daemon=True, args=(q,server, event_dispatcher,)).start() address = args.listen diff --git a/nodes.py b/nodes.py index 111ef8712..7eeea9650 100644 --- a/nodes.py +++ b/nodes.py @@ -36,6 +36,9 @@ def interrupt_processing(value=True): MAX_RESOLUTION=8192 class CLIPTextEncode: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + @classmethod def INPUT_TYPES(s): return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}} @@ -48,6 +51,8 @@ class CLIPTextEncode: return ([[clip.encode(text), {}]], ) class ConditioningCombine: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} @@ -60,6 +65,8 @@ class ConditioningCombine: return (conditioning_1 + conditioning_2, ) class ConditioningSetArea: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -86,8 +93,9 @@ class ConditioningSetArea: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): self.device = device + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -101,7 +109,8 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): + self.event_dispatcher = event_dispatcher self.device = device @classmethod @@ -116,7 +125,8 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): + self.event_dispatcher = event_dispatcher self.device = device @classmethod @@ -138,8 +148,9 @@ class VAEEncode: class VAEEncodeTiled: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): self.device = device + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -158,8 +169,9 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): self.device = device + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -192,6 +204,9 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) class CheckpointLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), @@ -207,6 +222,8 @@ class CheckpointLoader: return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CheckpointLoaderSimple: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), @@ -222,6 +239,8 @@ class CheckpointLoaderSimple: return out class DiffusersLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(cls): paths = [] @@ -246,6 +265,8 @@ class DiffusersLoader: class unCLIPCheckpointLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), @@ -261,6 +282,8 @@ class unCLIPCheckpointLoader: return out class CLIPSetLastLayer: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "clip": ("CLIP", ), @@ -296,6 +319,8 @@ class LoraLoader: return (model_lora, clip_lora) class TomePatchModel: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -312,6 +337,8 @@ class TomePatchModel: return (m, ) class VAELoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} @@ -327,6 +354,8 @@ class VAELoader: return (vae,) class ControlNetLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "control_net_name": (folder_paths.get_filename_list("controlnet"), )}} @@ -342,6 +371,8 @@ class ControlNetLoader: return (controlnet,) class DiffControlNetLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -359,6 +390,8 @@ class DiffControlNetLoader: class ControlNetApply: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -385,6 +418,8 @@ class ControlNetApply: return (c, ) class CLIPLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), @@ -400,6 +435,8 @@ class CLIPLoader: return (clip,) class CLIPVisionLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"), ), @@ -415,6 +452,8 @@ class CLIPVisionLoader: return (clip_vision,) class CLIPVisionEncode: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), @@ -430,6 +469,8 @@ class CLIPVisionEncode: return (output,) class StyleModelLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "style_model_name": (folder_paths.get_filename_list("style_models"), )}} @@ -446,6 +487,8 @@ class StyleModelLoader: class StyleModelApply: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -466,6 +509,8 @@ class StyleModelApply: return (c, ) class unCLIPConditioning: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -492,6 +537,8 @@ class unCLIPConditioning: return (c, ) class GLIGENLoader: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}} @@ -507,6 +554,8 @@ class GLIGENLoader: return (gligen,) class GLIGENTextBoxApply: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_to": ("CONDITIONING", ), @@ -538,8 +587,9 @@ class GLIGENTextBoxApply: return (c, ) class EmptyLatentImage: - def __init__(self, device="cpu"): + def __init__(self, device="cpu", event_dispatcher=None): self.device = device + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -557,6 +607,8 @@ class EmptyLatentImage: class LatentFromBatch: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -576,6 +628,8 @@ class LatentFromBatch: return (s,) class LatentUpscale: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -596,21 +650,29 @@ class LatentUpscale: return (s,) class SaveLatent: - @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "filename_prefix": ("STRING", {"default": "ComfyUI"})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "save" + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher - CATEGORY = "latent" + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "filename": ("STRING", {"default": "ComfyUI_latent.npy"})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "save" - def save(self, samples, filename_prefix): - s = samples.copy() - comfy.utils.save_latent(samples["samples"], filename_prefix) - return (samples,) + CATEGORY = "latent" + + def save(self, samples, filename): + s = samples.copy() + comfy.utils.save_latent(samples["samples"], filename) + + @clas + + return (samples,) class LoadLatent: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "filename": ("STRING", {"default": "ComfyUI_latent.npy"})}} @@ -624,6 +686,37 @@ class LoadLatent: derp = ({"samples": comfy.utils.load_latent(filename)},) return derp + + +class MuxLatent: + + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "latent1": ("LATENT",), + "latent2": ("LATENT",), + "weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "interpolate" + + CATEGORY = "latent" + + def interpolate(self, latent1, latent2, weight): + # Ensure the latents have the same shape + if latent1["samples"].shape != latent2["samples"].shape: + raise ValueError("Latents must have the same shape") + + # Interpolate the latents using the weight + interpolated_latent = latent1["samples"] * (1 - weight) + latent2["samples"] * weight + + return ({"samples": interpolated_latent},) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -649,6 +742,8 @@ class LatentRotate: return (s,) class LatentFlip: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -711,6 +806,8 @@ class LatentComposite: return (samples_out,) class LatentCrop: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -754,6 +851,8 @@ class LatentCrop: return (s,) class SetLatentNoiseMask: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -791,6 +890,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, return (out, ) class KSampler: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": @@ -815,6 +916,8 @@ class KSampler: return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) class KSamplerAdvanced: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): return {"required": @@ -848,9 +951,10 @@ class KSamplerAdvanced: return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) class SaveImage: - def __init__(self): + def __init__(self, event_dispatcher=None): self.output_dir = folder_paths.get_output_directory() self.type = "output" + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -924,8 +1028,9 @@ class SaveImage: return { "ui": { "images": results } } class PreviewImage(SaveImage): - def __init__(self): + def __init__(self, event_dispatcher=None): self.output_dir = folder_paths.get_temp_directory() + self.event_dispatcher = event_dispatcher self.type = "temp" @classmethod @@ -936,6 +1041,8 @@ class PreviewImage(SaveImage): } class LoadImage: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() @@ -976,6 +1083,8 @@ class LoadImage: return True class LoadImageMask: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): @@ -1024,6 +1133,8 @@ class LoadImageMask: return True class ImageScale: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -1045,6 +1156,8 @@ class ImageScale: return (s,) class ImageInvert: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -1061,6 +1174,8 @@ class ImageInvert: class ImagePadForOutpaint: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher @classmethod def INPUT_TYPES(s): @@ -1123,18 +1238,101 @@ class ImagePadForOutpaint: return (new_image, mask) class FrameCounter: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + @classmethod def INPUT_TYPES(s): - return {"required": { "frame": ("INT", {"default": 0})}} + return { + "required": { + "frame": ("INT", {"default": 0}), + "fired": ("BOOL", {"default": False}), + }, + } + @classmethod + def IS_CHANGED(cls, *args, **kwargs): + return True - RETURN_TYPES = ("INT",) + RETURN_TYPES = ("text",) FUNCTION = "frame_counter" CATEGORY = "operations" - def frame_counter(self, frame): + def frame_counter(self, frame, fired): + if fired: + frame += 1 return (frame,) +class EventListener: + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "event_type": (["node_started", "node_finished"],), + "class_type": ("STRING", {"default": "KSampler"}) + }, + } + + @classmethod + def IS_CHANGED(cls, *args, **kwargs): + return True + + RETURN_TYPES = ("BOOL",) + RETURN_NAMES = ("fired",) + + FUNCTION = "listen" + + CATEGORY = "Events" + + def listen(self, event_type, class_type): + self._fired = False + + def event_listener(event, event_data): + print(f"Got an event of type {event_data['event_type']} with data {event_data}") + if (event_data["event_type"] == event_type and event_data["class_type"] == class_type): + self._fired = True + + self.event_dispatcher.subscribe(event_type, event_listener) + + return (self._fired,) + +class PrinterNode: + + def __init__(self, event_dispatcher): + self.event_dispatcher = event_dispatcher + + @classmethod + def INPUT_TYPES(s): + return { + "required": {}, + "optional": { + "text": ("text",), + "latent": ("LATENT",), + } + } + @classmethod + def IS_CHANGED(cls, *args, **kwargs): + return True + + RETURN_TYPES = () + FUNCTION = "print_value" + CATEGORY = "operations" + OUTPUT_NODE = True + + def print_value(self, text=None, latent=None): + if latent is not None: + latent_hash = hashlib.sha256(latent["samples"].cpu().numpy().tobytes()).hexdigest() + print(f"Latent hash: {latent_hash}") + print(np.array2string(latent["samples"].cpu().numpy(), separator=', ')) + + + print(text) + return {"ui": {"": text}} + + NODE_CLASS_MAPPINGS = { "KSampler": KSampler, "CheckpointLoaderSimple": CheckpointLoaderSimple, @@ -1184,6 +1382,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, "FrameCounter": FrameCounter, + "PrinterNode": PrinterNode, + "EventListener": EventListener, + "MuxLatent": MuxLatent, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -1236,6 +1437,9 @@ NODE_DISPLAY_NAME_MAPPINGS = { "VAEEncodeTiled": "VAE Encode (Tiled)", # operations "FrameCounter": "Frame Counter", + "PrinterNode": "Print", + "EventListener": "Event Listener", + "MuxLatent": "Mux Latent", } def load_custom_node(module_path):