From 181a9bf26d4445e160645f6c81dc2ee29e7b6a08 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 9 Jul 2025 08:18:04 +0800 Subject: [PATCH 1/5] Support Multi Image-Caption dataset in lora training node (#8819) * initial impl of multi img/text dataset * Update nodes_train.py * Support Kohya-ss structure --- comfy_extras/nodes_train.py | 125 +++++++++++++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index fbff01010..17caf5ad5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -75,7 +75,7 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None"): +def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): """Utility function to load and process a list of images. Args: @@ -90,7 +90,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None"): raise ValueError("No valid images found in input") output_images = [] - w, h = None, None for file in image_files: image_path = os.path.join(input_dir, file) @@ -206,6 +205,103 @@ class LoadImageSetFromFolderNode: return (output_tensor,) +class LoadImageTextSetFromFolderNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), + "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + "width": ( + IO.INT, + { + "default": -1, + "min": -1, + "max": 10000, + "step": 1, + "tooltip": "The width to resize the images to. -1 means use the original width.", + }, + ), + "height": ( + IO.INT, + { + "default": -1, + "min": -1, + "max": 10000, + "step": 1, + "tooltip": "The height to resize the images to. -1 means use the original height.", + }, + ) + }, + } + + RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images and caption from a directory for training." + + def load_images(self, folder, clip, resize_method, width=None, height=None): + if clip is None: + raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") + + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend([ + os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) + ] * repeat) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") + for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + width = width if width != -1 else None + height = height if height != -1 else None + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + + logging.info(f"Encoding captions from {sub_input_dir}.") + conditions = [] + empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + for text in captions: + if text == "": + conditions.append(empty_cond) + tokens = clip.tokenize(text) + conditions.extend(clip.encode_from_tokens_scheduled(tokens)) + logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") + return (output_tensor, conditions) + + def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -381,6 +477,13 @@ class TrainLoraNode: latents = latents["samples"].to(dtype) num_images = latents.shape[0] + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + positive = positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) with torch.inference_mode(False): lora_sd = {} @@ -474,6 +577,7 @@ class TrainLoraNode: # setup models for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): patch(m) + mp.model.requires_grad_(False) comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) # Setup sampler and guider like in test script @@ -486,7 +590,6 @@ class TrainLoraNode: ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input - ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced() # yoland: this currently resize to the first image in the dataset @@ -495,21 +598,21 @@ class TrainLoraNode: try: for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): # Generate random sigma - sigma = mp.model.model_sampling.percent_to_sigma( + sigmas = [mp.model.model_sampling.percent_to_sigma( torch.rand((1,)).item() - ) - sigma = torch.tensor([sigma]) + ) for _ in range(min(batch_size, num_images))] + sigmas = torch.tensor(sigmas) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) indices = torch.randperm(num_images)[:batch_size] - ss.sample( - noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()} - ) + batch_latent = latents[indices].clone() + guider.set_conds([positive[i] for i in indices]) # Set conditioning from input + guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed) finally: for m in mp.model.modules(): unpatch(m) - del ss, train_sampler, optimizer + del train_sampler, optimizer torch.cuda.empty_cache() for adapter in all_weight_adapters: @@ -697,6 +800,7 @@ NODE_CLASS_MAPPINGS = { "SaveLoRANode": SaveLoRA, "LoraModelLoader": LoraModelLoader, "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, + "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, "LossGraphNode": LossGraphNode, } @@ -705,5 +809,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveLoRANode": "Save LoRA Weights", "LoraModelLoader": "Load LoRA Model", "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", + "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", "LossGraphNode": "Plot Loss Graph", } From 5612670ee48ce500aab98e362b3372ab06d1d659 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 9 Jul 2025 00:45:48 -0700 Subject: [PATCH 2/5] Remove unmaintained notebook. (#8845) --- README.md | 4 - notebooks/comfyui_colab.ipynb | 322 ---------------------------------- 2 files changed, 326 deletions(-) delete mode 100644 notebooks/comfyui_colab.ipynb diff --git a/README.md b/README.md index ba8892b17..0e021a687 100644 --- a/README.md +++ b/README.md @@ -178,10 +178,6 @@ If you have trouble extracting it, right click the file -> properties -> unblock See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. -## Jupyter Notebook - -To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb) - ## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb deleted file mode 100644 index 5560b5ff9..000000000 --- a/notebooks/comfyui_colab.ipynb +++ /dev/null @@ -1,322 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "aaaaaaaaaa" - }, - "source": [ - "Git clone the repo and install the requirements. (ignore the pip errors about protobuf)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bbbbbbbbbb" - }, - "outputs": [], - "source": [ - "#@title Environment Setup\n", - "\n", - "\n", - "OPTIONS = {}\n", - "\n", - "USE_GOOGLE_DRIVE = False #@param {type:\"boolean\"}\n", - "UPDATE_COMFY_UI = True #@param {type:\"boolean\"}\n", - "WORKSPACE = 'ComfyUI'\n", - "OPTIONS['USE_GOOGLE_DRIVE'] = USE_GOOGLE_DRIVE\n", - "OPTIONS['UPDATE_COMFY_UI'] = UPDATE_COMFY_UI\n", - "\n", - "if OPTIONS['USE_GOOGLE_DRIVE']:\n", - " !echo \"Mounting Google Drive...\"\n", - " %cd /\n", - " \n", - " from google.colab import drive\n", - " drive.mount('/content/drive')\n", - "\n", - " WORKSPACE = \"/content/drive/MyDrive/ComfyUI\"\n", - " %cd /content/drive/MyDrive\n", - "\n", - "![ ! -d $WORKSPACE ] && echo -= Initial setup ComfyUI =- && git clone https://github.com/comfyanonymous/ComfyUI\n", - "%cd $WORKSPACE\n", - "\n", - "if OPTIONS['UPDATE_COMFY_UI']:\n", - " !echo -= Updating ComfyUI =-\n", - " !git pull\n", - "\n", - "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cccccccccc" - }, - "source": [ - "Download some models/checkpoints/vae or custom comfyui nodes (uncomment the commands for the ones you want)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dddddddddd" - }, - "outputs": [], - "source": [ - "# Checkpoints\n", - "\n", - "### SDXL\n", - "### I recommend these workflow examples: https://comfyanonymous.github.io/ComfyUI_examples/sdxl/\n", - "\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n", - "\n", - "# SDXL ReVision\n", - "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n", - "\n", - "# SD1.5\n", - "!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "# SD2\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors -P ./models/checkpoints/\n", - "\n", - "# Some SD1.5 anime style\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix2/AbyssOrangeMix2_hard.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A1_orangemixs.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/Models/AbyssOrangeMix3/AOM3A3_orangemixs.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/anything-v3-fp16-pruned.safetensors -P ./models/checkpoints/\n", - "\n", - "# Waifu Diffusion 1.5 (anime style SD2.x 768-v)\n", - "#!wget -c https://huggingface.co/waifu-diffusion/wd-1-5-beta3/resolve/main/wd-illusion-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "\n", - "# unCLIP models\n", - "#!wget -c https://huggingface.co/comfyanonymous/illuminatiDiffusionV1_v11_unCLIP/resolve/main/illuminatiDiffusionV1_v11-unclip-h-fp16.safetensors -P ./models/checkpoints/\n", - "#!wget -c https://huggingface.co/comfyanonymous/wd-1.5-beta2_unCLIP/resolve/main/wd-1-5-beta2-aesthetic-unclip-h-fp16.safetensors -P ./models/checkpoints/\n", - "\n", - "\n", - "# VAE\n", - "!wget -c https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors -P ./models/vae/\n", - "#!wget -c https://huggingface.co/WarriorMama777/OrangeMixs/resolve/main/VAEs/orangemix.vae.pt -P ./models/vae/\n", - "#!wget -c https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt -P ./models/vae/\n", - "\n", - "\n", - "# Loras\n", - "#!wget -c https://civitai.com/api/download/models/10350 -O ./models/loras/theovercomer8sContrastFix_sd21768.safetensors #theovercomer8sContrastFix SD2.x 768-v\n", - "#!wget -c https://civitai.com/api/download/models/10638 -O ./models/loras/theovercomer8sContrastFix_sd15.safetensors #theovercomer8sContrastFix SD1.x\n", - "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors -P ./models/loras/ #SDXL offset noise lora\n", - "\n", - "\n", - "# T2I-Adapter\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_openpose_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_color_sd14v1.pth -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd14v1.pth -P ./models/controlnet/\n", - "\n", - "# T2I Styles Model\n", - "#!wget -c https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_style_sd14v1.pth -P ./models/style_models/\n", - "\n", - "# CLIPVision model (needed for styles model)\n", - "#!wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin -O ./models/clip_vision/clip_vit14.bin\n", - "\n", - "\n", - "# ControlNet\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_ip2p_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11e_sd15_shuffle_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_canny_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11f1p_sd15_depth_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_inpaint_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_lineart_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_mlsd_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_normalbae_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_openpose_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_scribble_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_seg_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15_softedge_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n", - "\n", - "# ControlNet SDXL\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n", - "#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n", - "\n", - "# Controlnet Preprocessor nodes by Fannovel16\n", - "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n", - "\n", - "\n", - "# GLIGEN\n", - "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n", - "\n", - "\n", - "# ESRGAN upscale model\n", - "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n", - "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", - "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kkkkkkkkkkkkkkk" - }, - "source": [ - "### Run ComfyUI with cloudflared (Recommended Way)\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jjjjjjjjjjjjjj" - }, - "outputs": [], - "source": [ - "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb\n", - "!dpkg -i cloudflared-linux-amd64.deb\n", - "\n", - "import subprocess\n", - "import threading\n", - "import time\n", - "import socket\n", - "import urllib.request\n", - "\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch cloudflared (if it gets stuck here cloudflared is having issues)\\n\")\n", - "\n", - " p = subprocess.Popen([\"cloudflared\", \"tunnel\", \"--url\", \"http://127.0.0.1:{}\".format(port)], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", - " for line in p.stderr:\n", - " l = line.decode()\n", - " if \"trycloudflare.com \" in l:\n", - " print(\"This is the URL to access ComfyUI:\", l[l.find(\"http\"):], end='')\n", - " #print(l, end='')\n", - "\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kkkkkkkkkkkkkk" - }, - "source": [ - "### Run ComfyUI with localtunnel\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jjjjjjjjjjjjj" - }, - "outputs": [], - "source": [ - "!npm install -g localtunnel\n", - "\n", - "import threading\n", - "\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", - "\n", - " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", - " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", - " for line in p.stdout:\n", - " print(line.decode(), end='')\n", - "\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gggggggggg" - }, - "source": [ - "### Run ComfyUI with colab iframe (use only in case the previous way with localtunnel doesn't work)\n", - "\n", - "You should see the ui appear in an iframe. If you get a 403 error, it's your firefox settings or an extension that's messing things up.\n", - "\n", - "If you want to open it in another window use the link.\n", - "\n", - "Note that some UI features like live image previews won't work because the colab iframe blocks websockets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hhhhhhhhhh" - }, - "outputs": [], - "source": [ - "import threading\n", - "def iframe_thread(port):\n", - " while True:\n", - " time.sleep(0.5)\n", - " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", - " result = sock.connect_ex(('127.0.0.1', port))\n", - " if result == 0:\n", - " break\n", - " sock.close()\n", - " from google.colab import output\n", - " output.serve_kernel_port_as_iframe(port, height=1024)\n", - " print(\"to open it in a window you can open this link here:\")\n", - " output.serve_kernel_port_as_window(port)\n", - "\n", - "threading.Thread(target=iframe_thread, daemon=True, args=(8188,)).start()\n", - "\n", - "!python main.py --dont-print-server" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} From 1205afc708d963d160f38c1d6613a384ddf6c564 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 9 Jul 2025 23:41:22 +0800 Subject: [PATCH 3/5] Better training loop implementation (#8820) --- comfy_extras/nodes_train.py | 122 +++++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 42 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 17caf5ad5..3d05fdab5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -23,38 +23,78 @@ from comfy.comfy_types.node_typing import IO from comfy.weight_adapter import adapters +def make_batch_extra_option_dict(d, indicies, full_size=None): + new_dict = {} + for k, v in d.items(): + newv = v + if isinstance(v, dict): + newv = make_batch_extra_option_dict(v, indicies, full_size=full_size) + elif isinstance(v, torch.Tensor): + if full_size is None or v.size(0) == full_size: + newv = v[indicies] + elif isinstance(v, (list, tuple)) and len(v) == full_size: + newv = [v[i] for i in indicies] + new_dict[k] = newv + return new_dict + + class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None): + def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback + self.batch_size = batch_size + self.total_steps = total_steps + self.seed = seed + self.training_dtype = training_dtype def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - self.optimizer.zero_grad() - noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) - latent = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(sigmas), - torch.zeros_like(noise, requires_grad=True), - latent_image, - False - ) + cond = model_wrap.conds["positive"] + dataset_size = sigmas.size(0) + torch.cuda.empty_cache() + for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) + indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - # Ensure model is in training mode and computing gradients - # x0 pred - denoised = model_wrap(noise, sigmas, **extra_args) - try: - loss = self.loss_fn(denoised, latent.clone()) - except RuntimeError as e: - if "does not require grad and does not have a grad_fn" in str(e): - logging.info("WARNING: This is likely due to the model is loaded in inference mode.") - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - self.optimizer.step() - # torch.cuda.memory._dump_snapshot("trainn.pickle") - # torch.cuda.memory._record_memory_history(enabled=None) + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, + batch_noise, + batch_latent, + False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False + ) + + model_wrap.conds["positive"] = [ + cond[i] for i in indicies + ] + batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) + loss = self.loss_fn(x0_pred, x0) + loss.backward() + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + self.optimizer.step() + self.optimizer.zero_grad() + torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -584,36 +624,34 @@ class TrainLoraNode: loss_map = {"loss": []} def loss_callback(loss): loss_map["loss"].append(loss) - pbar.set_postfix({"loss": f"{loss:.4f}"}) train_sampler = TrainSampler( - criterion, optimizer, loss_callback=loss_callback + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + total_steps=steps, + seed=seed, + training_dtype=dtype ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input - # yoland: this currently resize to the first image in the dataset - # Training loop - torch.cuda.empty_cache() try: - for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - # Generate random sigma - sigmas = [mp.model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(batch_size, num_images))] - sigmas = torch.tensor(sigmas) - - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) - - indices = torch.randperm(num_images)[:batch_size] - batch_latent = latents[indices].clone() - guider.set_conds([positive[i] for i in indices]) # Set conditioning from input - guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed) + # Generate dummy sigmas and noise + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed + ) finally: for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer - torch.cuda.empty_cache() for adapter in all_weight_adapters: adapter.requires_grad_(False) From 1fd306824d35bf2669f6be46fadab37efd7081c4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 9 Jul 2025 22:03:27 -0700 Subject: [PATCH 4/5] Add warning to catch torch import mistakes. (#8852) --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index d488c0f4c..0b1987ef4 100644 --- a/main.py +++ b/main.py @@ -127,6 +127,9 @@ if __name__ == "__main__": import cuda_malloc +if 'torch' in sys.modules: + logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + import comfy.utils import execution From 2b653e8c18f18792e7f080df611c8a35f1d0fdf4 Mon Sep 17 00:00:00 2001 From: guill Date: Thu, 10 Jul 2025 11:46:19 -0700 Subject: [PATCH 5/5] Support for async node functions (#8830) * Support for async execution functions This commit adds support for node execution functions defined as async. When a node's execution function is defined as async, we can continue executing other nodes while it is processing. Standard uses of `await` should "just work", but people will still have to be careful if they spawn actual threads. Because torch doesn't really have async/await versions of functions, this won't particularly help with most locally-executing nodes, but it does work for e.g. web requests to other machines. In addition to the execute function, the `VALIDATE_INPUTS` and `check_lazy_status` functions can also be defined as async, though we'll only resolve one node at a time right now for those. * Add the execution model tests to CI * Add a missing file It looks like this got caught by .gitignore? There's probably a better place to put it, but I'm not sure what that is. * Add the websocket library for automated tests * Add additional tests for async error cases Also fixes one bug that was found when an async function throws an error after being scheduled on a task. * Add a feature flags message to reduce bandwidth We now only send 1 preview message of the latest type the client can support. We'll add a console warning when the client fails to send a feature flags message at some point in the future. * Add async tests to CI * Don't actually add new tests in this PR Will do it in a separate PR * Resolve unit test in GPU-less runner * Just remove the tests that GHA can't handle * Change line endings to UNIX-style * Avoid loading model_management.py so early Because model_management.py has a top-level `logging.info`, we have to be careful not to import that file before we call `setup_logging`. If we do, we end up having the default logging handler registered in addition to our custom one. --- comfy/utils.py | 5 +- comfy_api/feature_flags.py | 69 +++ comfy_execution/caching.py | 53 +-- comfy_execution/graph.py | 20 +- comfy_execution/progress.py | 347 +++++++++++++++ comfy_execution/utils.py | 46 ++ execution.py | 135 ++++-- main.py | 32 +- protocol.py | 7 + server.py | 98 ++++- tests-unit/feature_flags_test.py | 98 +++++ tests-unit/requirements.txt | 1 + tests-unit/websocket_feature_flags_test.py | 77 ++++ tests/inference/extra_model_paths.yaml | 2 +- tests/inference/test_async_nodes.py | 410 ++++++++++++++++++ tests/inference/test_execution.py | 65 ++- .../testing_nodes/testing-pack/__init__.py | 49 ++- .../testing-pack/async_test_nodes.py | 343 +++++++++++++++ .../testing-pack/specific_tests.py | 136 ++++++ 19 files changed, 1898 insertions(+), 95 deletions(-) create mode 100644 comfy_api/feature_flags.py create mode 100644 comfy_execution/progress.py create mode 100644 comfy_execution/utils.py create mode 100644 protocol.py create mode 100644 tests-unit/feature_flags_test.py create mode 100644 tests-unit/websocket_feature_flags_test.py create mode 100644 tests/inference/test_async_nodes.py create mode 100644 tests/inference/testing_nodes/testing-pack/async_test_nodes.py diff --git a/comfy/utils.py b/comfy/utils.py index 47981d8f6..f8e01f713 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -998,11 +998,12 @@ def set_progress_bar_global_hook(function): PROGRESS_BAR_HOOK = function class ProgressBar: - def __init__(self, total): + def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK self.total = total self.current = 0 self.hook = PROGRESS_BAR_HOOK + self.node_id = node_id def update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1011,7 +1012,7 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview) + self.hook(self.current, self.total, preview, node_id=self.node_id) def update(self, value): self.update_absolute(self.current + value) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py new file mode 100644 index 000000000..0d4389a6e --- /dev/null +++ b/comfy_api/feature_flags.py @@ -0,0 +1,69 @@ +""" +Feature flags module for ComfyUI WebSocket protocol negotiation. + +This module handles capability negotiation between frontend and backend, +allowing graceful protocol evolution while maintaining backward compatibility. +""" + +from typing import Any, Dict + +from comfy.cli_args import args + +# Default server capabilities +SERVER_FEATURE_FLAGS: Dict[str, Any] = { + "supports_preview_metadata": True, + "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes +} + + +def get_connection_feature( + sockets_metadata: Dict[str, Dict[str, Any]], + sid: str, + feature_name: str, + default: Any = False +) -> Any: + """ + Get a feature flag value for a specific connection. + + Args: + sockets_metadata: Dictionary of socket metadata + sid: Session ID of the connection + feature_name: Name of the feature to check + default: Default value if feature not found + + Returns: + Feature value or default if not found + """ + if sid not in sockets_metadata: + return default + + return sockets_metadata[sid].get("feature_flags", {}).get(feature_name, default) + + +def supports_feature( + sockets_metadata: Dict[str, Dict[str, Any]], + sid: str, + feature_name: str +) -> bool: + """ + Check if a connection supports a specific feature. + + Args: + sockets_metadata: Dictionary of socket metadata + sid: Session ID of the connection + feature_name: Name of the feature to check + + Returns: + Boolean indicating if feature is supported + """ + return get_connection_feature(sockets_metadata, sid, feature_name, False) is True + + +def get_server_features() -> Dict[str, Any]: + """ + Get the server's feature flags. + + Returns: + Dictionary of server feature flags + """ + return SERVER_FEATURE_FLAGS.copy() diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index dbb37b89f..41224ce3b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,6 +1,7 @@ import itertools from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt +from abc import ABC, abstractmethod import nodes @@ -16,12 +17,13 @@ def include_unique_id_in_input(class_type: str) -> bool: NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] -class CacheKeySet: +class CacheKeySet(ABC): def __init__(self, dynprompt, node_ids, is_changed_cache): self.keys = {} self.subcache_keys = {} - def add_keys(self, node_ids): + @abstractmethod + async def add_keys(self, node_ids): raise NotImplementedError() def all_node_ids(self): @@ -60,9 +62,8 @@ class CacheKeySetID(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed_cache): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt - self.add_keys(node_ids) - def add_keys(self, node_ids): + async def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue @@ -77,37 +78,36 @@ class CacheKeySetInputSignature(CacheKeySet): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache - self.add_keys(node_ids) def include_node_id_in_input(self) -> bool: return False - def add_keys(self, node_ids): + async def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue if not self.dynprompt.has_node(node_id): continue node = self.dynprompt.get_node(node_id) - self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id) self.subcache_keys[node_id] = (node_id, node["class_type"]) - def get_node_signature(self, dynprompt, node_id): + async def get_node_signature(self, dynprompt, node_id): signature = [] ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) - signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: - signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) return to_hashable(signature) - def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, self.is_changed_cache.get(node_id)] + signature = [class_type, await self.is_changed_cache.get(node_id)] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) inputs = node["inputs"] @@ -150,9 +150,10 @@ class BasicCache: self.cache = {} self.subcaches = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + await self.cache_key_set.add_keys(node_ids) self.is_changed_cache = is_changed_cache self.initialized = True @@ -201,13 +202,13 @@ class BasicCache: else: return None - def _ensure_subcache(self, node_id, children_ids): + async def _ensure_subcache(self, node_id, children_ids): subcache_key = self.cache_key_set.get_subcache_key(node_id) subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) self.subcaches[subcache_key] = subcache - subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache def _get_subcache(self, node_id): @@ -259,10 +260,10 @@ class HierarchicalCache(BasicCache): assert cache is not None cache._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) assert cache is not None - return cache._ensure_subcache(node_id, children_ids) + return await cache._ensure_subcache(node_id, children_ids) class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): @@ -273,8 +274,8 @@ class LRUCache(BasicCache): self.used_generation = {} self.children = {} - def set_prompt(self, dynprompt, node_ids, is_changed_cache): - super().set_prompt(dynprompt, node_ids, is_changed_cache) + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + await super().set_prompt(dynprompt, node_ids, is_changed_cache) self.generation += 1 for node_id in node_ids: self._mark_used(node_id) @@ -303,11 +304,11 @@ class LRUCache(BasicCache): self._mark_used(node_id) return self._set_immediate(node_id, value) - def ensure_subcache_for(self, node_id, children_ids): + async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes - super()._ensure_subcache(node_id, children_ids) + await super()._ensure_subcache(node_id, children_ids) - self.cache_key_set.add_keys(children_ids) + await self.cache_key_set.add_keys(children_ids) self._mark_used(node_id) cache_key = self.cache_key_set.get_data_key(node_id) self.children[cache_key] = [] @@ -337,7 +338,7 @@ class DependencyAwareCache(BasicCache): self.ancestors = {} # Maps node_id -> set of ancestor node_ids self.executed_nodes = set() # Tracks nodes that have been executed - def set_prompt(self, dynprompt, node_ids, is_changed_cache): + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): """ Clear the entire cache and rebuild the dependency graph. @@ -354,7 +355,7 @@ class DependencyAwareCache(BasicCache): self.executed_nodes.clear() # Call the parent method to initialize the cache with the new prompt - super().set_prompt(dynprompt, node_ids, is_changed_cache) + await super().set_prompt(dynprompt, node_ids, is_changed_cache) # Rebuild the dependency graph self._build_dependency_graph(dynprompt, node_ids) @@ -405,7 +406,7 @@ class DependencyAwareCache(BasicCache): """ return self._get_immediate(node_id) - def ensure_subcache_for(self, node_id, children_ids): + async def ensure_subcache_for(self, node_id, children_ids): """ Ensure a subcache exists for a node and update dependencies. @@ -416,7 +417,7 @@ class DependencyAwareCache(BasicCache): Returns: The subcache object for the node. """ - subcache = super()._ensure_subcache(node_id, children_ids) + subcache = await super()._ensure_subcache(node_id, children_ids) for child_id in children_ids: self.descendants[node_id].add(child_id) self.ancestors[child_id].add(node_id) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index a2799b52e..c79243e1e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Type, Literal import nodes +import asyncio from comfy_execution.graph_utils import is_link from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions @@ -100,6 +101,8 @@ class TopologicalSort: self.pendingNodes = {} self.blockCount = {} # Number of nodes this node is directly blocked by self.blocking = {} # Which nodes are blocked by this node + self.externalBlocks = 0 + self.unblockedEvent = asyncio.Event() def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -153,6 +156,16 @@ class TopologicalSort: for link in links: self.add_strong_link(*link) + def add_external_block(self, node_id): + assert node_id in self.blockCount, "Can't add external block to a node that isn't pending" + self.externalBlocks += 1 + self.blockCount[node_id] += 1 + def unblock(): + self.externalBlocks -= 1 + self.blockCount[node_id] -= 1 + self.unblockedEvent.set() + return unblock + def is_cached(self, node_id): return False @@ -181,11 +194,16 @@ class ExecutionList(TopologicalSort): def is_cached(self, node_id): return self.output_cache.get(node_id) is not None - def stage_node_execution(self): + async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): return None, None, None available = self.get_ready_nodes() + while len(available) == 0 and self.externalBlocks > 0: + # Wait for an external block to be released + await self.unblockedEvent.wait() + self.unblockedEvent.clear() + available = self.get_ready_nodes() if len(available) == 0: cycled_nodes = self.get_nodes_in_cycle() # Because cycles composed entirely of static nodes are caught during initial validation, diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py new file mode 100644 index 000000000..5645b3e3c --- /dev/null +++ b/comfy_execution/progress.py @@ -0,0 +1,347 @@ +from typing import TypedDict, Dict, Optional +from typing_extensions import override +from PIL import Image +from enum import Enum +from abc import ABC +from tqdm import tqdm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy_execution.graph import DynamicPrompt +from protocol import BinaryEventTypes +from comfy_api import feature_flags + + +class NodeState(Enum): + Pending = "pending" + Running = "running" + Finished = "finished" + Error = "error" + + +class NodeProgressState(TypedDict): + """ + A class to represent the state of a node's progress. + """ + + state: NodeState + value: float + max: float + + +class ProgressHandler(ABC): + """ + Abstract base class for progress handlers. + Progress handlers receive progress updates and display them in various ways. + """ + + def __init__(self, name: str): + self.name = name + self.enabled = True + + def set_registry(self, registry: "ProgressRegistry"): + pass + + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + """Called when a node starts processing""" + pass + + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: Optional[Image.Image] = None, + ): + """Called when a node's progress is updated""" + pass + + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + """Called when a node finishes processing""" + pass + + def reset(self): + """Called when the progress registry is reset""" + pass + + def enable(self): + """Enable this handler""" + self.enabled = True + + def disable(self): + """Disable this handler""" + self.enabled = False + + +class CLIProgressHandler(ProgressHandler): + """ + Handler that displays progress using tqdm progress bars in the CLI. + """ + + def __init__(self): + super().__init__("cli") + self.progress_bars: Dict[str, tqdm] = {} + + @override + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Create a new tqdm progress bar + if node_id not in self.progress_bars: + self.progress_bars[node_id] = tqdm( + total=state["max"], + desc=f"Node {node_id}", + unit="steps", + leave=True, + position=len(self.progress_bars), + ) + + @override + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: Optional[Image.Image] = None, + ): + # Handle case where start_handler wasn't called + if node_id not in self.progress_bars: + self.progress_bars[node_id] = tqdm( + total=max_value, + desc=f"Node {node_id}", + unit="steps", + leave=True, + position=len(self.progress_bars), + ) + self.progress_bars[node_id].update(value) + else: + # Update existing progress bar + if max_value != self.progress_bars[node_id].total: + self.progress_bars[node_id].total = max_value + # Calculate the update amount (difference from current position) + current_position = self.progress_bars[node_id].n + update_amount = value - current_position + if update_amount > 0: + self.progress_bars[node_id].update(update_amount) + + @override + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Complete and close the progress bar if it exists + if node_id in self.progress_bars: + # Ensure the bar shows 100% completion + remaining = state["max"] - self.progress_bars[node_id].n + if remaining > 0: + self.progress_bars[node_id].update(remaining) + self.progress_bars[node_id].close() + del self.progress_bars[node_id] + + @override + def reset(self): + # Close all progress bars + for bar in self.progress_bars.values(): + bar.close() + self.progress_bars.clear() + + +class WebUIProgressHandler(ProgressHandler): + """ + Handler that sends progress updates to the WebUI via WebSockets. + """ + + def __init__(self, server_instance): + super().__init__("webui") + self.server_instance = server_instance + + def set_registry(self, registry: "ProgressRegistry"): + self.registry = registry + + def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]): + """Send the current progress state to the client""" + if self.server_instance is None: + return + + # Only send info for non-pending nodes + active_nodes = { + node_id: { + "value": state["value"], + "max": state["max"], + "state": state["state"].value, + "node_id": node_id, + "prompt_id": prompt_id, + "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), + "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), + "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), + } + for node_id, state in nodes.items() + if state["state"] != NodeState.Pending + } + + # Send a combined progress_state message with all node states + self.server_instance.send_sync( + "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes} + ) + + @override + def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + + @override + def update_handler( + self, + node_id: str, + value: float, + max_value: float, + state: NodeProgressState, + prompt_id: str, + image: Optional[Image.Image] = None, + ): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + if image: + # Only send new format if client supports it + if feature_flags.supports_feature( + self.server_instance.sockets_metadata, + self.server_instance.client_id, + "supports_preview_metadata", + ): + metadata = { + "node_id": node_id, + "prompt_id": prompt_id, + "display_node_id": self.registry.dynprompt.get_display_node_id( + node_id + ), + "parent_node_id": self.registry.dynprompt.get_parent_node_id( + node_id + ), + "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), + } + self.server_instance.send_sync( + BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, + (image, metadata), + self.server_instance.client_id, + ) + + @override + def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): + # Send progress state of all nodes + if self.registry: + self._send_progress_state(prompt_id, self.registry.nodes) + + +class ProgressRegistry: + """ + Registry that maintains node progress state and notifies registered handlers. + """ + + def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): + self.prompt_id = prompt_id + self.dynprompt = dynprompt + self.nodes: Dict[str, NodeProgressState] = {} + self.handlers: Dict[str, ProgressHandler] = {} + + def register_handler(self, handler: ProgressHandler) -> None: + """Register a progress handler""" + self.handlers[handler.name] = handler + + def unregister_handler(self, handler_name: str) -> None: + """Unregister a progress handler""" + if handler_name in self.handlers: + # Allow handler to clean up resources + self.handlers[handler_name].reset() + del self.handlers[handler_name] + + def enable_handler(self, handler_name: str) -> None: + """Enable a progress handler""" + if handler_name in self.handlers: + self.handlers[handler_name].enable() + + def disable_handler(self, handler_name: str) -> None: + """Disable a progress handler""" + if handler_name in self.handlers: + self.handlers[handler_name].disable() + + def ensure_entry(self, node_id: str) -> NodeProgressState: + """Ensure a node entry exists""" + if node_id not in self.nodes: + self.nodes[node_id] = NodeProgressState( + state=NodeState.Pending, value=0, max=1 + ) + return self.nodes[node_id] + + def start_progress(self, node_id: str) -> None: + """Start progress tracking for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Running + entry["value"] = 0.0 + entry["max"] = 1.0 + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.start_handler(node_id, entry, self.prompt_id) + + def update_progress( + self, node_id: str, value: float, max_value: float, image: Optional[Image.Image] + ) -> None: + """Update progress for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Running + entry["value"] = value + entry["max"] = max_value + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.update_handler( + node_id, value, max_value, entry, self.prompt_id, image + ) + + def finish_progress(self, node_id: str) -> None: + """Finish progress tracking for a node""" + entry = self.ensure_entry(node_id) + entry["state"] = NodeState.Finished + entry["value"] = entry["max"] + + # Notify all enabled handlers + for handler in self.handlers.values(): + if handler.enabled: + handler.finish_handler(node_id, entry, self.prompt_id) + + def reset_handlers(self) -> None: + """Reset all handlers""" + for handler in self.handlers.values(): + handler.reset() + +# Global registry instance +global_progress_registry: ProgressRegistry | None = None + +def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: + global global_progress_registry + + # Reset existing handlers if registry exists + if global_progress_registry is not None: + global_progress_registry.reset_handlers() + + # Create new registry + global_progress_registry = ProgressRegistry(prompt_id, dynprompt) + + +def add_progress_handler(handler: ProgressHandler) -> None: + registry = get_progress_state() + handler.set_registry(registry) + registry.register_handler(handler) + + +def get_progress_state() -> ProgressRegistry: + global global_progress_registry + if global_progress_registry is None: + from comfy_execution.graph import DynamicPrompt + + global_progress_registry = ProgressRegistry( + prompt_id="", dynprompt=DynamicPrompt({}) + ) + return global_progress_registry diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py new file mode 100644 index 000000000..62d32f101 --- /dev/null +++ b/comfy_execution/utils.py @@ -0,0 +1,46 @@ +import contextvars +from typing import Optional, NamedTuple + +class ExecutionContext(NamedTuple): + """ + Context information about the currently executing node. + + Attributes: + node_id: The ID of the currently executing node + list_index: The index in a list being processed (for operations on batches/lists) + """ + prompt_id: str + node_id: str + list_index: Optional[int] + +current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None) + +def get_executing_context() -> Optional[ExecutionContext]: + return current_executing_context.get(None) + +class CurrentNodeContext: + """ + Context manager for setting the current executing node context. + + Sets the current_executing_context on enter and resets it on exit. + + Example: + with CurrentNodeContext(node_id="123", list_index=0): + # Code that should run with the current node context set + process_image() + """ + def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None): + self.context = ExecutionContext( + prompt_id= prompt_id, + node_id= node_id, + list_index= list_index + ) + self.token = None + + def __enter__(self): + self.token = current_executing_context.set(self.context) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.token is not None: + current_executing_context.reset(self.token) diff --git a/execution.py b/execution.py index f6006fa12..90cefc023 100644 --- a/execution.py +++ b/execution.py @@ -8,12 +8,14 @@ import time import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional +import asyncio import torch import comfy.model_management import nodes from comfy_execution.caching import ( + BasicCache, CacheKeySetID, CacheKeySetInputSignature, DependencyAwareCache, @@ -28,6 +30,8 @@ from comfy_execution.graph import ( ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input +from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.utils import CurrentNodeContext class ExecutionResult(Enum): @@ -39,12 +43,13 @@ class DuplicateNodeError(Exception): pass class IsChangedCache: - def __init__(self, dynprompt, outputs_cache): + def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): + self.prompt_id = prompt_id self.dynprompt = dynprompt self.outputs_cache = outputs_cache self.is_changed = {} - def get(self, node_id): + async def get(self, node_id): if node_id in self.is_changed: return self.is_changed[node_id] @@ -62,7 +67,8 @@ class IsChangedCache: # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED") + is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: logging.warning("WARNING: {}".format(e)) @@ -164,7 +170,19 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e map_node_over_list = None #Don't hook this please -def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): +async def resolve_map_node_over_list_results(results): + remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()] + if len(remaining) == 0: + return [x.result() if isinstance(x, asyncio.Task) else x for x in results] + else: + done, pending = await asyncio.wait(remaining) + for task in done: + exc = task.exception() + if exc is not None: + raise exc + return [x.result() if isinstance(x, asyncio.Task) else x for x in results] + +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -178,7 +196,7 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] - def process_inputs(inputs, index=None, input_is_list=False): + async def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: nodes.before_node_execution() execution_block = None @@ -194,20 +212,37 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut if execution_block is None: if pre_execute_cb is not None and index is not None: pre_execute_cb(index) - results.append(getattr(obj, func)(**inputs)) + f = getattr(obj, func) + if inspect.iscoroutinefunction(f): + async def async_wrapper(f, prompt_id, unique_id, list_index, args): + with CurrentNodeContext(prompt_id, unique_id, list_index): + return await f(**args) + task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) + # Give the task a chance to execute without yielding + await asyncio.sleep(0) + if task.done(): + result = task.result() + results.append(result) + else: + results.append(task) + else: + with CurrentNodeContext(prompt_id, unique_id, index): + result = f(**inputs) + results.append(result) else: results.append(execution_block) if input_is_list: - process_inputs(input_data_all, 0, input_is_list=input_is_list) + await process_inputs(input_data_all, 0, input_is_list=input_is_list) elif max_len_input == 0: - process_inputs({}) + await process_inputs({}) else: for i in range(max_len_input): input_dict = slice_dict(input_data_all, i) - process_inputs(input_dict, i) + await process_inputs(input_dict, i) return results + def merge_result_data(results, obj): # check which outputs need concatenating output = [] @@ -229,11 +264,18 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) + if has_pending_task: + return return_values, {}, False, has_pending_task + output, ui, has_subgraph = get_output_from_returns(return_values, obj) + return output, ui, has_subgraph, False + +def get_output_from_returns(return_values, obj): results = [] uis = [] subgraph_results = [] - return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) has_subgraph = False for i in range(len(return_values)): r = return_values[i] @@ -267,6 +309,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb else: output = [] ui = dict() + # TODO: Think there's an existing bug here + # If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet. + # They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of + # any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future. if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui, has_subgraph @@ -279,7 +325,7 @@ def format_value(x): else: return str(x) -def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -291,11 +337,26 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + get_progress_state().finish_progress(unique_id) return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - if unique_id in pending_subgraph_results: + if unique_id in pending_async_nodes: + results = [] + for r in pending_async_nodes[unique_id]: + if isinstance(r, asyncio.Task): + try: + results.append(r.result()) + except Exception as ex: + # An async task failed - propagate the exception up + del pending_async_nodes[unique_id] + raise ex + else: + results.append(r) + del pending_async_nodes[unique_id] + output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) + elif unique_id in pending_subgraph_results: cached_results = pending_subgraph_results[unique_id] resolved_outputs = [] for is_subgraph, result in cached_results: @@ -317,6 +378,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: + get_progress_state().start_progress(unique_id) input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id @@ -328,7 +390,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp caches.objects.set(unique_id, obj) if hasattr(obj, "check_lazy_status"): - required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( x not in input_data_all or x in missing_keys @@ -357,8 +420,18 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp else: return block def pre_execute_cb(call_index): + # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + if has_pending_tasks: + pending_async_nodes[unique_id] = output_data + unblock = execution_list.add_external_block(unique_id) + async def await_completion(): + tasks = [x for x in output_data if isinstance(x, asyncio.Task)] + await asyncio.gather(*tasks, return_exceptions=True) + unblock() + asyncio.create_task(await_completion()) + return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: caches.ui.set(unique_id, { "meta": { @@ -401,7 +474,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: - cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) + subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) for link in new_output_links: @@ -446,6 +520,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.FAILURE, error_details, ex) + get_progress_state().finish_progress(unique_id) executed.add(unique_id) return (ExecutionResult.SUCCESS, None, None) @@ -500,6 +575,11 @@ class PromptExecutor: self.add_message("execution_error", mes, broadcast=False) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + asyncio_loop = asyncio.new_event_loop() + asyncio.set_event_loop(asyncio_loop) + asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) + + async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -512,9 +592,11 @@ class PromptExecutor: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) - is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) for cache in self.caches.all: - cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) cache.clean_unused() cached_nodes = [] @@ -527,6 +609,7 @@ class PromptExecutor: { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -534,12 +617,13 @@ class PromptExecutor: execution_list.add_node(node_id) while not execution_list.is_empty(): - node_id, error, ex = execution_list.stage_node_execution() + node_id, error, ex = await execution_list.stage_node_execution() if error is not None: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -569,7 +653,7 @@ class PromptExecutor: comfy.model_management.unload_all_models() -def validate_inputs(prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated): unique_id = item if unique_id in validated: return validated[unique_id] @@ -646,7 +730,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = await validate_inputs(prompt_id, prompt, o_id, validated) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -771,7 +855,8 @@ def validate_inputs(prompt, item, validated): input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) - ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS") + ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): if r is not True and not isinstance(r, ExecutionBlocker): @@ -804,7 +889,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' + klass.__qualname__ -def validate_prompt(prompt): +async def validate_prompt(prompt_id, prompt): outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -847,7 +932,7 @@ def validate_prompt(prompt): valid = False reasons = [] try: - m = validate_inputs(prompt, o, validated) + m = await validate_inputs(prompt_id, prompt, o, validated) valid = m[0] reasons = m[1] except Exception as ex: diff --git a/main.py b/main.py index 0b1987ef4..2b4ffafd4 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,9 @@ import itertools import utils.extra_config import logging import sys +from comfy_execution.progress import get_progress_state +from comfy_execution.utils import get_executing_context +from comfy_api import feature_flags if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -134,7 +137,7 @@ import comfy.utils import execution import server -from server import BinaryEventTypes +from protocol import BinaryEventTypes import nodes import comfy.model_management import comfyui_version @@ -230,15 +233,34 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop() ) - def hijack_progress(server_instance): - def hook(value, total, preview_image): + def hook(value, total, preview_image, prompt_id=None, node_id=None): + executing_context = get_executing_context() + if prompt_id is None and executing_context is not None: + prompt_id = executing_context.prompt_id + if node_id is None and executing_context is not None: + node_id = executing_context.node_id comfy.model_management.throw_exception_if_processing_interrupted() - progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id} + if prompt_id is None: + prompt_id = server_instance.last_prompt_id + if node_id is None: + node_id = server_instance.last_node_id + progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id} + get_progress_state().update_progress(node_id, value, total, preview_image) server_instance.send_sync("progress", progress, server_instance.client_id) if preview_image is not None: - server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id) + # Only send old method if client doesn't support preview metadata + if not feature_flags.supports_feature( + server_instance.sockets_metadata, + server_instance.client_id, + "supports_preview_metadata", + ): + server_instance.send_sync( + BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, + preview_image, + server_instance.client_id, + ) comfy.utils.set_progress_bar_global_hook(hook) diff --git a/protocol.py b/protocol.py new file mode 100644 index 000000000..038a0a840 --- /dev/null +++ b/protocol.py @@ -0,0 +1,7 @@ + +class BinaryEventTypes: + PREVIEW_IMAGE = 1 + UNENCODED_PREVIEW_IMAGE = 2 + TEXT = 3 + PREVIEW_IMAGE_WITH_METADATA = 4 + diff --git a/server.py b/server.py index 878b5eeb1..e8bad9f4e 100644 --- a/server.py +++ b/server.py @@ -26,6 +26,7 @@ import mimetypes from comfy.cli_args import args import comfy.utils import comfy.model_management +from comfy_api import feature_flags import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager @@ -35,11 +36,7 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes - -class BinaryEventTypes: - PREVIEW_IMAGE = 1 - UNENCODED_PREVIEW_IMAGE = 2 - TEXT = 3 +from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): try: @@ -178,6 +175,7 @@ class PromptServer(): max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() + self.sockets_metadata = dict() self.web_root = ( FrontendManager.init_frontend(args.front_end_version) if args.front_end_root is None @@ -202,20 +200,53 @@ class PromptServer(): else: sid = uuid.uuid4().hex + # Store WebSocket for backward compatibility self.sockets[sid] = ws + # Store metadata separately + self.sockets_metadata[sid] = {"feature_flags": {}} try: # Send initial state to the new client - await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid) + await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: await self.send("executing", { "node": self.last_node_id }, sid) + # Flag to track if we've received the first message + first_message = True + async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: logging.warning('ws connection closed with exception %s' % ws.exception()) + elif msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + # Check if first message is feature flags + if first_message and data.get("type") == "feature_flags": + # Store client feature flags + client_flags = data.get("data", {}) + self.sockets_metadata[sid]["feature_flags"] = client_flags + + # Send server feature flags in response + await self.send( + "feature_flags", + feature_flags.get_server_features(), + sid, + ) + + logging.info( + f"Feature flags negotiated for client {sid}: {client_flags}" + ) + first_message = False + except json.JSONDecodeError: + logging.warning( + f"Invalid JSON received from client {sid}: {msg.data}" + ) + except Exception as e: + logging.error(f"Error processing WebSocket message: {e}") finally: self.sockets.pop(sid, None) + self.sockets_metadata.pop(sid, None) return ws @routes.get("/") @@ -548,6 +579,10 @@ class PromptServer(): } return web.json_response(system_stats) + @routes.get("/features") + async def get_features(request): + return web.json_response(feature_flags.get_server_features()) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) @@ -643,7 +678,8 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] - valid = execution.validate_prompt(prompt) + prompt_id = str(uuid.uuid4()) + valid = await execution.validate_prompt(prompt_id, prompt) extra_data = {} if "extra_data" in json_data: extra_data = json_data["extra_data"] @@ -651,7 +687,6 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} @@ -766,6 +801,10 @@ class PromptServer(): async def send(self, event, data, sid=None): if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: await self.send_image(data, sid=sid) + elif event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: + # data is (preview_image, metadata) + preview_image, metadata = data + await self.send_image_with_metadata(preview_image, metadata, sid=sid) elif isinstance(data, (bytes, bytearray)): await self.send_bytes(event, data, sid) else: @@ -804,6 +843,43 @@ class PromptServer(): preview_bytes = bytesIO.getvalue() await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) + async def send_image_with_metadata(self, image_data, metadata=None, sid=None): + image_type = image_data[0] + image = image_data[1] + max_size = image_data[2] + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + resampling = Image.Resampling.LANCZOS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + + mimetype = "image/png" if image_type == "PNG" else "image/jpeg" + + # Prepare metadata + if metadata is None: + metadata = {} + metadata["image_type"] = mimetype + + # Serialize metadata as JSON + import json + metadata_json = json.dumps(metadata).encode('utf-8') + metadata_length = len(metadata_json) + + # Prepare image data + bytesIO = BytesIO() + image.save(bytesIO, format=image_type, quality=95, compress_level=1) + image_bytes = bytesIO.getvalue() + + # Combine metadata and image + combined_data = bytearray() + combined_data.extend(struct.pack(">I", metadata_length)) + combined_data.extend(metadata_json) + combined_data.extend(image_bytes) + + await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, combined_data, sid=sid) + async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) @@ -845,10 +921,10 @@ class PromptServer(): ssl_ctx = None scheme = "http" if args.tls_keyfile and args.tls_certfile: - ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) - ssl_ctx.load_cert_chain(certfile=args.tls_certfile, + ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) + ssl_ctx.load_cert_chain(certfile=args.tls_certfile, keyfile=args.tls_keyfile) - scheme = "https" + scheme = "https" if verbose: logging.info("Starting server\n") diff --git a/tests-unit/feature_flags_test.py b/tests-unit/feature_flags_test.py new file mode 100644 index 000000000..f2702cfc8 --- /dev/null +++ b/tests-unit/feature_flags_test.py @@ -0,0 +1,98 @@ +"""Tests for feature flags functionality.""" + +from comfy_api.feature_flags import ( + get_connection_feature, + supports_feature, + get_server_features, + SERVER_FEATURE_FLAGS, +) + + +class TestFeatureFlags: + """Test suite for feature flags functions.""" + + def test_get_server_features_returns_copy(self): + """Test that get_server_features returns a copy of the server flags.""" + features = get_server_features() + # Verify it's a copy by modifying it + features["test_flag"] = True + # Original should be unchanged + assert "test_flag" not in SERVER_FEATURE_FLAGS + + def test_get_server_features_contains_expected_flags(self): + """Test that server features contain expected flags.""" + features = get_server_features() + assert "supports_preview_metadata" in features + assert features["supports_preview_metadata"] is True + assert "max_upload_size" in features + assert isinstance(features["max_upload_size"], (int, float)) + + def test_get_connection_feature_with_missing_sid(self): + """Test getting feature for non-existent session ID.""" + sockets_metadata = {} + result = get_connection_feature(sockets_metadata, "missing_sid", "some_feature") + assert result is False # Default value + + def test_get_connection_feature_with_custom_default(self): + """Test getting feature with custom default value.""" + sockets_metadata = {} + result = get_connection_feature( + sockets_metadata, "missing_sid", "some_feature", default="custom_default" + ) + assert result == "custom_default" + + def test_get_connection_feature_with_feature_flags(self): + """Test getting feature from connection with feature flags.""" + sockets_metadata = { + "sid1": { + "feature_flags": { + "supports_preview_metadata": True, + "custom_feature": "value", + }, + } + } + result = get_connection_feature(sockets_metadata, "sid1", "supports_preview_metadata") + assert result is True + + result = get_connection_feature(sockets_metadata, "sid1", "custom_feature") + assert result == "value" + + def test_get_connection_feature_missing_feature(self): + """Test getting non-existent feature from connection.""" + sockets_metadata = { + "sid1": {"feature_flags": {"existing_feature": True}} + } + result = get_connection_feature(sockets_metadata, "sid1", "missing_feature") + assert result is False + + def test_supports_feature_returns_boolean(self): + """Test that supports_feature always returns boolean.""" + sockets_metadata = { + "sid1": { + "feature_flags": { + "bool_feature": True, + "string_feature": "value", + "none_feature": None, + }, + } + } + + # True boolean feature + assert supports_feature(sockets_metadata, "sid1", "bool_feature") is True + + # Non-boolean values should return False + assert supports_feature(sockets_metadata, "sid1", "string_feature") is False + assert supports_feature(sockets_metadata, "sid1", "none_feature") is False + assert supports_feature(sockets_metadata, "sid1", "missing_feature") is False + + def test_supports_feature_with_missing_connection(self): + """Test supports_feature with missing connection.""" + sockets_metadata = {} + assert supports_feature(sockets_metadata, "missing_sid", "any_feature") is False + + def test_empty_feature_flags_dict(self): + """Test connection with empty feature flags dictionary.""" + sockets_metadata = {"sid1": {"feature_flags": {}}} + result = get_connection_feature(sockets_metadata, "sid1", "any_feature") + assert result is False + assert supports_feature(sockets_metadata, "sid1", "any_feature") is False diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index d70d00f4b..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -1,3 +1,4 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio +websocket-client diff --git a/tests-unit/websocket_feature_flags_test.py b/tests-unit/websocket_feature_flags_test.py new file mode 100644 index 000000000..e93b2e1dd --- /dev/null +++ b/tests-unit/websocket_feature_flags_test.py @@ -0,0 +1,77 @@ +"""Simplified tests for WebSocket feature flags functionality.""" +from comfy_api import feature_flags + + +class TestWebSocketFeatureFlags: + """Test suite for WebSocket feature flags integration.""" + + def test_server_feature_flags_response(self): + """Test server feature flags are properly formatted.""" + features = feature_flags.get_server_features() + + # Check expected server features + assert "supports_preview_metadata" in features + assert features["supports_preview_metadata"] is True + assert "max_upload_size" in features + assert isinstance(features["max_upload_size"], (int, float)) + + def test_progress_py_checks_feature_flags(self): + """Test that progress.py checks feature flags before sending metadata.""" + # This simulates the check in progress.py + client_id = "test_client" + sockets_metadata = {"test_client": {"feature_flags": {}}} + + # The actual check would be in progress.py + supports_metadata = feature_flags.supports_feature( + sockets_metadata, client_id, "supports_preview_metadata" + ) + + assert supports_metadata is False + + def test_multiple_clients_different_features(self): + """Test handling multiple clients with different feature support.""" + sockets_metadata = { + "modern_client": { + "feature_flags": {"supports_preview_metadata": True} + }, + "legacy_client": { + "feature_flags": {} + } + } + + # Check modern client + assert feature_flags.supports_feature( + sockets_metadata, "modern_client", "supports_preview_metadata" + ) is True + + # Check legacy client + assert feature_flags.supports_feature( + sockets_metadata, "legacy_client", "supports_preview_metadata" + ) is False + + def test_feature_negotiation_message_format(self): + """Test the format of feature negotiation messages.""" + # Client message format + client_message = { + "type": "feature_flags", + "data": { + "supports_preview_metadata": True, + "api_version": "1.0.0" + } + } + + # Verify structure + assert client_message["type"] == "feature_flags" + assert "supports_preview_metadata" in client_message["data"] + + # Server response format (what would be sent) + server_features = feature_flags.get_server_features() + server_message = { + "type": "feature_flags", + "data": server_features + } + + # Verify structure + assert server_message["type"] == "feature_flags" + assert "supports_preview_metadata" in server_message["data"] + assert server_message["data"]["supports_preview_metadata"] is True diff --git a/tests/inference/extra_model_paths.yaml b/tests/inference/extra_model_paths.yaml index 75b2e1ae4..68e056564 100644 --- a/tests/inference/extra_model_paths.yaml +++ b/tests/inference/extra_model_paths.yaml @@ -1,4 +1,4 @@ # Config for testing nodes testing: - custom_nodes: tests/inference/testing_nodes + custom_nodes: testing_nodes diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py new file mode 100644 index 000000000..b243bbca9 --- /dev/null +++ b/tests/inference/test_async_nodes.py @@ -0,0 +1,410 @@ +import pytest +import time +import torch +import urllib.error +import numpy as np +import subprocess + +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.inference.test_execution import ComfyClient + + +@pytest.mark.execution +class TestAsyncNodes: + @fixture(scope="class", autouse=True, params=[ + (False, 0), + (True, 0), + (True, 100), + ]) + def _server(self, args_pytest, request): + pargs = [ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] + # Running server with args: pargs + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + except ConnectionRefusedError: + # Retrying... + pass + else: + break + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"async_nodes[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + # Happy Path Tests + + def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test that a basic async node executes correctly.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) + output = g.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g) + + # Verify execution completed + assert result.did_run(sleep_node), "Async sleep node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify the image passed through correctly + result_images = result.get_images(output) + assert len(result_images) == 1, "Should have 1 image" + assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" + + def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test that multiple async nodes execute in parallel.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async sleep nodes with different durations + sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4) + sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Add outputs for each + _output1 = g.node("PreviewImage", images=sleep1.out(0)) + _output2 = g.node("PreviewImage", images=sleep2.out(0)) + _output3 = g.node("PreviewImage", images=sleep3.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should take ~0.5s (max duration) not 1.2s (sum of durations) + assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" + + # Verify all nodes executed + assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) + + def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with proper dependency handling.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Chain of async operations + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Average depends on both async results + average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + + # Verify execution order + assert result.did_run(sleep1) and result.did_run(sleep2) + assert result.did_run(average) and result.did_run(output) + + # Verify averaged result + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" + + def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): + """Test async VALIDATE_INPUTS function.""" + g = builder + # Create a test node with async validation + validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + # Should pass validation + result = client.run(g) + assert result.did_run(validation_node) + + # Test validation failure + validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with lazy evaluation.""" + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + # Create async nodes that will be evaluated lazily + sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3) + sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3) + + # Use lazy mix that only needs sleep1 (mask=0.0) + lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should only execute sleep1, not sleep2 + assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" + assert result.did_run(sleep1), "Sleep1 should have executed" + assert not result.did_run(sleep2), "Sleep2 should have been skipped" + + def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): + """Test async check_lazy_status function.""" + g = builder + # Create a node with async check_lazy_status + lazy_node = g.node("TestAsyncLazyCheck", + input1="value1", + input2="value2", + condition=True) + g.node("SaveImage", images=lazy_node.out(0)) + + result = client.run(g) + assert result.did_run(lazy_node) + + # Error Handling Tests + + def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): + """Test that async execution errors are properly handled.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Create an async node that will error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" + + def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): + """Test async validation error handling.""" + g = builder + # Node with async validation that will fail + validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) + g.node("SaveImage", images=validation_node.out(0)) + + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.run(g) + # Verify it's a validation error + assert exc_info.value.code == 400 + + def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): + """Test handling of async operations that timeout.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Very long sleep that would timeout + timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0) + g.node("SaveImage", images=timeout_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised a timeout error" + except Exception as e: + assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" + + def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): + """Test that workflow can recover after async errors.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # First run with error + error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) + g.node("SaveImage", images=error_node.out(0)) + + try: + client.run(g) + except Exception: + pass # Expected + + # Second run should succeed + g2 = GraphBuilder(prefix="recovery_test") + image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) + g2.node("SaveImage", images=sleep_node.out(0)) + + result = client.run(g2) + assert result.did_run(sleep_node), "Should be able to run after error" + + def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test handling when sync node errors while async node is executing.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Async node that takes time + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5) + + # Sync node that will error immediately + error_node = g.node("TestSyncError", value=image.out(0)) + + # Both feed into output + g.node("PreviewImage", images=sleep_node.out(0)) + g.node("PreviewImage", images=error_node.out(0)) + + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + # Verify the sync error was caught even though async was running + assert 'prompt_id' in e.args[0] + + # Edge Cases + + def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes with execution blockers.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Async sleep nodes + sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) + sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) + + # Create list of images + image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0)) + + # Create list of blocking conditions - [False, True] to block only the second item + int1 = g.node("StubInt", value=1) + int2 = g.node("StubInt", value=2) + block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0)) + + # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2) + compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==") + + # Block based on the comparison results + blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) + + output = g.node("PreviewImage", images=blocker.out(0)) + + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have blocked second image" + + def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): + """Test that async nodes are properly cached.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) + g.node("SaveImage", images=sleep_node.out(0)) + + # First run + result1 = client.run(g) + assert result1.did_run(sleep_node), "Should run first time" + + # Second run - should be cached + start_time = time.time() + result2 = client.run(g) + elapsed_time = time.time() - start_time + + assert not result2.did_run(sleep_node), "Should be cached" + assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" + + def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): + """Test async nodes within dynamically generated prompts.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Node that generates async nodes dynamically + dynamic_async = g.node("TestDynamicAsyncGeneration", + image1=image1.out(0), + image2=image2.out(0), + num_async_nodes=3, + sleep_duration=0.2) + g.node("SaveImage", images=dynamic_async.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Should execute async nodes in parallel within dynamic prompt + assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" + assert result.did_run(dynamic_async) + + def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): + """Test that async resources are properly cleaned up.""" + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create multiple async nodes that use resources + resource_nodes = [] + for i in range(5): + node = g.node("TestAsyncResourceUser", + value=image.out(0), + resource_id=f"resource_{i}", + duration=0.1) + resource_nodes.append(node) + g.node("PreviewImage", images=node.out(0)) + + result = client.run(g) + + # Verify all nodes executed + for node in resource_nodes: + assert result.did_run(node) + + # Run again to ensure resources were cleaned up + result2 = client.run(g) + # Should be cached but not error due to resource conflicts + for node in resource_nodes: + assert not result2.did_run(node), "Should be cached" + + def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): + """Test cancellation of async operations.""" + # This would require implementing cancellation in the client + # For now, we'll test that long-running async operations can be interrupted + pass # TODO: Implement when cancellation API is available + + def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): + """Test workflows with both sync and async nodes.""" + g = builder + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + # Mix of sync and async operations + # Sync: lazy mix images + sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0)) + # Async: sleep + async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2) + # Sync: custom validation + sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5) + # Async: sleep again + async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2) + + output = g.node("SaveImage", images=async_op2.out(0)) + + result = client.run(g) + + # Verify all nodes executed in correct order + assert result.did_run(sync_op1) + assert result.did_run(async_op1) + assert result.did_run(sync_op2) + assert result.did_run(async_op2) + + # Image should be a mix of black and white (gray) + result_images = result.get_images(output) + avg_value = np.array(result_images[0]).mean() + assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75" diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 5cda5c1ae..9d3d685cc 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -252,7 +252,7 @@ class TestExecution: @pytest.mark.parametrize("test_type, test_value", [ ("StubInt", 5), - ("StubFloat", 5.0) + ("StubMask", 5.0) ]) def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): g = builder @@ -497,6 +497,69 @@ class TestExecution: assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert not result.did_run(test_node), "The execution should have been cached" + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create sleep nodes for each duration + sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8) + sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) + + # Add outputs to verify the execution + _output1 = g.node("PreviewImage", images=sleep_node1.out(0)) + _output2 = g.node("PreviewImage", images=sleep_node2.out(0)) + _output3 = g.node("PreviewImage", images=sleep_node3.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # The test should take around 0.4 seconds (the longest sleep duration) + # plus some overhead, but definitely less than the sum of all sleeps (0.9s) + # We'll allow for up to 0.8s total to account for overhead + assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s" + + # Verify that all nodes executed + assert result.did_run(sleep_node1), "Sleep node 1 should have run" + assert result.did_run(sleep_node2), "Sleep node 2 should have run" + assert result.did_run(sleep_node3), "Sleep node 3 should have run" + + def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + g = builder + # Create input images with different values + image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + image3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a TestParallelSleep node that expands into multiple TestSleep nodes + parallel_sleep = g.node("TestParallelSleep", + image1=image1.out(0), + image2=image2.out(0), + image3=image3.out(0), + sleep1=0.4, + sleep2=0.5, + sleep3=0.6) + output = g.node("SaveImage", images=parallel_sleep.out(0)) + + start_time = time.time() + result = client.run(g) + elapsed_time = time.time() - start_time + + # Similar to the previous test, expect parallel execution of the sleep nodes + # which should complete in less than the sum of all sleeps + assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s" + + # Verify the parallel sleep node executed + assert result.did_run(parallel_sleep), "ParallelSleep node should have run" + + # Verify we get an image as output (blend of the three input images) + result_images = result.get_images(output) + assert len(result_images) == 1, "Should have 1 image" + # Average pixel value should be around 170 (255 * 2 // 3) + avg_value = numpy.array(result_images[0]).mean() + assert avg_value == 170, f"Image average value {avg_value} should be 170" + # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, # only that one entry in the list is blocked. diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py index dcc71659a..20f9533c7 100644 --- a/tests/inference/testing_nodes/testing-pack/__init__.py +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -1,23 +1,26 @@ -from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS -from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS -from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS -from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS -from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS - -# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) -# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) - -NODE_CLASS_MAPPINGS = {} -NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) -NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) - -NODE_DISPLAY_NAME_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) - +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS +from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS) + diff --git a/tests/inference/testing_nodes/testing-pack/async_test_nodes.py b/tests/inference/testing_nodes/testing-pack/async_test_nodes.py new file mode 100644 index 000000000..547eea6f4 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/async_test_nodes.py @@ -0,0 +1,343 @@ +import torch +import asyncio +from typing import Dict +from comfy.utils import ProgressBar +from comfy_execution.graph_utils import GraphBuilder +from comfy.comfy_types.node_typing import ComfyNodeABC +from comfy.comfy_types import IO + + +class TestAsyncValidation(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "threshold": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, threshold): + # Simulate async validation (e.g., checking remote service) + await asyncio.sleep(0.05) + + if value > threshold: + return f"Value {value} exceeds threshold {threshold}" + return True + + def process(self, value, threshold): + # Create image based on value + intensity = value / 10.0 + image = torch.ones([1, 512, 512, 3]) * intensity + return (image,) + + +class TestAsyncError(ComfyNodeABC): + """Test node that errors during async execution.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "error_execution" + CATEGORY = "_for_testing/async" + + async def error_execution(self, value, error_after): + await asyncio.sleep(error_after) + raise RuntimeError("Intentional async execution error for testing") + + +class TestAsyncValidationError(ComfyNodeABC): + """Test node with async validation that always fails.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 5.0}), + "max_value": ("FLOAT", {"default": 10.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + @classmethod + async def VALIDATE_INPUTS(cls, value, max_value): + await asyncio.sleep(0.05) + # Always fail validation for values > max_value + if value > max_value: + return f"Async validation failed: {value} > {max_value}" + return True + + def process(self, value, max_value): + # This won't be reached if validation fails + image = torch.ones([1, 512, 512, 3]) * (value / max_value) + return (image,) + + +class TestAsyncTimeout(ComfyNodeABC): + """Test node that simulates timeout scenarios.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}), + "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "timeout_execution" + CATEGORY = "_for_testing/async" + + async def timeout_execution(self, value, timeout, operation_time): + try: + # This will timeout if operation_time > timeout + await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout) + return (value,) + except asyncio.TimeoutError: + raise RuntimeError(f"Operation timed out after {timeout} seconds") + + +class TestSyncError(ComfyNodeABC): + """Test node that errors synchronously (for mixed sync/async testing).""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "sync_error" + CATEGORY = "_for_testing/async" + + def sync_error(self, value): + raise RuntimeError("Intentional sync execution error for testing") + + +class TestAsyncLazyCheck(ComfyNodeABC): + """Test node with async check_lazy_status.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": (IO.ANY, {"lazy": True}), + "input2": (IO.ANY, {"lazy": True}), + "condition": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing/async" + + async def check_lazy_status(self, condition, input1, input2): + # Simulate async checking (e.g., querying remote service) + await asyncio.sleep(0.05) + + needed = [] + if condition and input1 is None: + needed.append("input1") + if not condition and input2 is None: + needed.append("input2") + return needed + + def process(self, input1, input2, condition): + # Return a simple image + return (torch.ones([1, 512, 512, 3]),) + + +class TestDynamicAsyncGeneration(ComfyNodeABC): + """Test node that dynamically generates async nodes.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}), + "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "generate_async_workflow" + CATEGORY = "_for_testing/async" + + def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): + g = GraphBuilder() + + # Create multiple async sleep nodes + sleep_nodes = [] + for i in range(num_async_nodes): + image = image1 if i % 2 == 0 else image2 + sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration) + sleep_nodes.append(sleep_node) + + # Average all results + if len(sleep_nodes) == 1: + final_node = sleep_nodes[0] + else: + avg_inputs = {"input1": sleep_nodes[0].out(0)} + for i, node in enumerate(sleep_nodes[1:], 2): + avg_inputs[f"input{i}"] = node.out(0) + final_node = g.node("TestVariadicAverage", **avg_inputs) + + return { + "result": (final_node.out(0),), + "expand": g.finalize(), + } + + +class TestAsyncResourceUser(ComfyNodeABC): + """Test node that uses resources during async execution.""" + + # Class-level resource tracking for testing + _active_resources: Dict[str, bool] = {} + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "resource_id": ("STRING", {"default": "resource_0"}), + "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "use_resource" + CATEGORY = "_for_testing/async" + + async def use_resource(self, value, resource_id, duration): + # Check if resource is already in use + if self._active_resources.get(resource_id, False): + raise RuntimeError(f"Resource {resource_id} is already in use!") + + # Mark resource as in use + self._active_resources[resource_id] = True + + try: + # Simulate resource usage + await asyncio.sleep(duration) + return (value,) + finally: + # Always clean up resource + self._active_resources[resource_id] = False + + +class TestAsyncBatchProcessing(ComfyNodeABC): + """Test async processing of batched inputs.""" + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": ("IMAGE",), + "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process_batch" + CATEGORY = "_for_testing/async" + + async def process_batch(self, images, process_time_per_item, unique_id): + batch_size = images.shape[0] + pbar = ProgressBar(batch_size, node_id=unique_id) + + # Process each image in the batch + processed = [] + for i in range(batch_size): + # Simulate async processing + await asyncio.sleep(process_time_per_item) + + # Simple processing: invert the image + processed_image = 1.0 - images[i:i+1] + processed.append(processed_image) + + pbar.update(1) + + # Stack processed images + result = torch.cat(processed, dim=0) + return (result,) + + +class TestAsyncConcurrentLimit(ComfyNodeABC): + """Test concurrent execution limits for async nodes.""" + + _semaphore = asyncio.Semaphore(2) # Only allow 2 concurrent executions + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}), + "node_id": ("INT", {"default": 0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "limited_execution" + CATEGORY = "_for_testing/async" + + async def limited_execution(self, value, duration, node_id): + async with self._semaphore: + # Node {node_id} acquired semaphore + await asyncio.sleep(duration) + # Node {node_id} releasing semaphore + return (value,) + + +# Add node mappings +ASYNC_TEST_NODE_CLASS_MAPPINGS = { + "TestAsyncValidation": TestAsyncValidation, + "TestAsyncError": TestAsyncError, + "TestAsyncValidationError": TestAsyncValidationError, + "TestAsyncTimeout": TestAsyncTimeout, + "TestSyncError": TestSyncError, + "TestAsyncLazyCheck": TestAsyncLazyCheck, + "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration, + "TestAsyncResourceUser": TestAsyncResourceUser, + "TestAsyncBatchProcessing": TestAsyncBatchProcessing, + "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit, +} + +ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAsyncValidation": "Test Async Validation", + "TestAsyncError": "Test Async Error", + "TestAsyncValidationError": "Test Async Validation Error", + "TestAsyncTimeout": "Test Async Timeout", + "TestSyncError": "Test Sync Error", + "TestAsyncLazyCheck": "Test Async Lazy Check", + "TestDynamicAsyncGeneration": "Test Dynamic Async Generation", + "TestAsyncResourceUser": "Test Async Resource User", + "TestAsyncBatchProcessing": "Test Async Batch Processing", + "TestAsyncConcurrentLimit": "Test Async Concurrent Limit", +} diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 9d05ab14f..657d49f2f 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,6 +1,11 @@ import torch +import time +import asyncio +from comfy.utils import ProgressBar from .tools import VariantSupport from comfy_execution.graph_utils import GraphBuilder +from comfy.comfy_types.node_typing import ComfyNodeABC +from comfy.comfy_types import IO class TestLazyMixImages: @classmethod @@ -333,6 +338,131 @@ class TestMixedExpansionReturns: "expand": g.finalize(), } +class TestSamplingInExpansion: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP",), + "vae": ("VAE",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 100}), + "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0}), + "prompt": ("STRING", {"multiline": True, "default": "a beautiful landscape with mountains and trees"}), + "negative_prompt": ("STRING", {"multiline": True, "default": "blurry, bad quality, worst quality"}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sampling_in_expansion" + + CATEGORY = "Testing/Nodes" + + def sampling_in_expansion(self, model, clip, vae, seed, steps, cfg, prompt, negative_prompt): + g = GraphBuilder() + + # Create a basic image generation workflow using the input model, clip and vae + # 1. Setup text prompts using the provided CLIP model + positive_prompt = g.node("CLIPTextEncode", + text=prompt, + clip=clip) + negative_prompt = g.node("CLIPTextEncode", + text=negative_prompt, + clip=clip) + + # 2. Create empty latent with specified size + empty_latent = g.node("EmptyLatentImage", width=512, height=512, batch_size=1) + + # 3. Setup sampler and generate image latent + sampler = g.node("KSampler", + model=model, + positive=positive_prompt.out(0), + negative=negative_prompt.out(0), + latent_image=empty_latent.out(0), + seed=seed, + steps=steps, + cfg=cfg, + sampler_name="euler_ancestral", + scheduler="normal") + + # 4. Decode latent to image using VAE + output = g.node("VAEDecode", samples=sampler.out(0), vae=vae) + + return { + "result": (output.out(0),), + "expand": g.finalize(), + } + +class TestSleep(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + RETURN_TYPES = (IO.ANY,) + FUNCTION = "sleep" + + CATEGORY = "_for_testing" + + async def sleep(self, value, seconds, unique_id): + pbar = ProgressBar(seconds, node_id=unique_id) + start = time.time() + expiration = start + seconds + now = start + while now < expiration: + now = time.time() + pbar.update_absolute(now - start) + await asyncio.sleep(0.01) + return (value,) + +class TestParallelSleep(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE", ), + "image2": ("IMAGE", ), + "image3": ("IMAGE", ), + "sleep1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "sleep2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "sleep3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "parallel_sleep" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id): + # Create a graph dynamically with three TestSleep nodes + g = GraphBuilder() + + # Create sleep nodes for each duration and image + sleep_node1 = g.node("TestSleep", value=image1, seconds=sleep1) + sleep_node2 = g.node("TestSleep", value=image2, seconds=sleep2) + sleep_node3 = g.node("TestSleep", value=image3, seconds=sleep3) + + # Blend the results using TestVariadicAverage + blend = g.node("TestVariadicAverage", + input1=sleep_node1.out(0), + input2=sleep_node2.out(0), + input3=sleep_node3.out(0)) + + return { + "result": (blend.out(0),), + "expand": g.finalize(), + } + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -345,6 +475,9 @@ TEST_NODE_CLASS_MAPPINGS = { "TestCustomValidation5": TestCustomValidation5, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, "TestMixedExpansionReturns": TestMixedExpansionReturns, + "TestSamplingInExpansion": TestSamplingInExpansion, + "TestSleep": TestSleep, + "TestParallelSleep": TestParallelSleep, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -359,4 +492,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestCustomValidation5": "Custom Validation 5", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", "TestMixedExpansionReturns": "Mixed Expansion Returns", + "TestSamplingInExpansion": "Sampling In Expansion", + "TestSleep": "Test Sleep", + "TestParallelSleep": "Test Parallel Sleep", }