diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 69b889c80..5d1c8f8cf 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -1,9 +1,9 @@ # Define a class for your command-line arguments import enum -from typing import Optional, List, TypeAlias, Callable +from typing import Optional, List, Callable import configargparse as argparse -ConfigurationExtender: TypeAlias = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]] +ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]] class LatentPreviewMethod(enum.Enum): diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index 5a99a6e9e..923e40b71 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -205,6 +205,28 @@ class SamplerDPMPP_3M_SDE: sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) return (sampler, ) +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + class SamplerDPMPP_2M_SDE: @classmethod def INPUT_TYPES(s): diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py similarity index 84% rename from custom_nodes/websocket_image_save.py.disabled rename to custom_nodes/websocket_image_save.py index b85a5de8b..5aa573642 100644 --- a/custom_nodes/websocket_image_save.py.disabled +++ b/custom_nodes/websocket_image_save.py @@ -10,10 +10,6 @@ import time #binary images on the websocket with a 8 byte header indicating the type #of binary message (first 4 bytes) and the image format (next 4 bytes). -#The reason this node is disabled by default is because there is a small -#issue when using it with the default ComfyUI web interface: When generating -#batches only the last image will be shown in the UI. - #Note that no metadata will be put in the images saved with this node. class SaveImageWebsocket: @@ -28,7 +24,7 @@ class SaveImageWebsocket: OUTPUT_NODE = True - CATEGORY = "image" + CATEGORY = "api/image" def save_images(self, images): pbar = comfy.utils.ProgressBar(images.shape[0]) diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py new file mode 100644 index 000000000..737488621 --- /dev/null +++ b/script_examples/websockets_api_example_ws_images.py @@ -0,0 +1,159 @@ +#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without +#them being saved to disk + +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import json +import urllib.request +import urllib.parse + +server_address = "127.0.0.1:8188" +client_id = str(uuid.uuid4()) + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: + return response.read() + +def get_history(prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: + return json.loads(response.read()) + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)['prompt_id'] + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] == prompt_id: + if data['node'] is None: + break #Execution is done + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + "steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "v1-5-pruned-emaonly.ckpt" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) +#set the text prompt for our positive CLIPTextEncode +prompt["6"]["inputs"]["text"] = "masterpiece best quality man" + +#set the seed for our KSampler node +prompt["3"]["inputs"]["seed"] = 5 + +ws = websocket.WebSocket() +ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) +images = get_images(ws, prompt) + +#Commented out code to display the output images: + +# for node_id in images: +# for image_data in images[node_id]: +# from PIL import Image +# import io +# image = Image.open(io.BytesIO(image_data)) +# image.show() +