diff --git a/apiconnector/apiconnector.py b/apiconnector/apiconnector.py new file mode 100644 index 000000000..7016b3d3e --- /dev/null +++ b/apiconnector/apiconnector.py @@ -0,0 +1,284 @@ +import uuid +import json +import urllib.request +import urllib.parse +from PIL import Image +from websocket import WebSocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import io +import requests +import time +import os +import subprocess + +APP_NAME = os.getenv('APP_NAME') # Name of the application +API_COMMAND_LINE = os.getenv('API_COMMAND_LINE') # Command line to start the API server +API_URL = os.getenv('API_URL') # URL of the API server +INSTANCE_IDENTIFIER = APP_NAME+'-'+str(uuid.uuid4()) # Unique identifier for this instance of the worker +TEST_PAYLOAD = os.getenv('TEST_PAYLOAD') + +class ComfyConnector: + _instance = None + _process = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(ComfyConnector, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, 'initialized'): + self.server_address = API_URL + self.client_id = INSTANCE_IDENTIFIER + self.ws = WebSocket() + self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}") + self.initialized = True + self.start_api() + + def start_api(self): # This method is used to start the API server + api_command_line = API_COMMAND_LINE + if self._process is None or self._process.poll() is not None: # Check if the process is not running or has terminated for some reason + self._process = subprocess.Popen(api_command_line.split()) + print("API process started with PID:", self._process.pid) + while not self.is_api_running(): # Block execution until the API server is running + time.sleep(0.5) # Wait for 0.5 seconds before checking again + + def is_api_running(self): # This method is used to check if the API server is running + test_payload = TEST_PAYLOAD + try: + response = requests.get(API_URL) + if response.status_code == 200: # Check if the API server tells us it's running by returning a 200 status code + test_image = self.generate_images(payload=test_payload) + if test_image: # this ensures that the API server is actually running and not just the web server + return True + return False + except Exception as e: + print("API not running:", e) + return False + + def kill_api(self): # This method is used to kill the API server + if self._process is not None and self._process.poll() is None: + self._process.kill() + self._process = None + print("API process killed") + + def get_history(self, prompt_id): # This method is used to retrieve the history of a prompt from the API server + with urllib.request.urlopen(f"http://{self.server_address}/history/{prompt_id}") as response: + return json.loads(response.read()) + + def get_image(self, filename, subfolder, folder_type): # This method is used to retrieve an image from the API server + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen(f"http://{self.server_address}/view?{url_values}") as response: + return response.read() + + def queue_prompt(self, prompt): # This method is used to queue a prompt for execution + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + print("Sending data:", data) # Print the data for debugging + headers = {'Content-Type': 'application/json'} # Set Content-Type header + req = urllib.request.Request(f"http://{self.server_address}/prompt", data=data, headers=headers) + return json.loads(urllib.request.urlopen(req).read()) + + def generate_images(self, prompt): # This method is used to generate images from a prompt and is the main method of this class + prompt_id = self.queue_prompt(prompt)['prompt_id'] + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break + address = self.find_output_node(prompt) + history = self.get_history(prompt_id)[prompt_id] + filenames = eval(f"history['outputs']{address}")['images'] # Extract all images + images = [] + for img_info in filenames: + filename = img_info['filename'] + subfolder = img_info['subfolder'] + folder_type = img_info['type'] + image_data = self.get_image(filename, subfolder, folder_type) + image_file = io.BytesIO(image_data) + image = Image.open(image_file) + images.append(image) + return images + + def upload_image(self, filepath, subfolder=None, folder_type=None, overwrite=False): # This method is used to upload an image to the API server for use in img2img or controlnet + url = f"http://{self.server_address}/upload/image" + files = {'image': open(filepath, 'rb')} + data = { + 'overwrite': str(overwrite).lower() + } + if subfolder: + data['subfolder'] = subfolder + if folder_type: + data['type'] = folder_type + response = requests.post(url, files=files, data=data) + return response.json() + + @staticmethod + def find_output_node(json_object): # This method is used to find the node containing the SaveImage class in a prompt + for key, value in json_object.items(): + if isinstance(value, dict): + if value.get("class_type") == "SaveImage": + return f"['{key}']" # Return the key containing the SaveImage class + result = ComfyConnector.find_output_node(value) + if result: + return result + return None + + @staticmethod + def replace_key_value(json_object, target_key, new_value, class_type_list=None, exclude=True): # This method is used to edit the payload of a prompt + for key, value in json_object.items(): + # Check if the current value is a dictionary and apply the logic recursively + if isinstance(value, dict): + class_type = value.get('class_type') + # Determine whether to apply the logic based on exclude and class_type_list + should_apply_logic = ( + (exclude and (class_type_list is None or class_type not in class_type_list)) or + (not exclude and (class_type_list is not None and class_type in class_type_list)) + ) + # Apply the logic to replace the target key with the new value if conditions are met + if should_apply_logic and target_key in value: + value[target_key] = new_value + # Recurse vertically (into nested dictionaries) + ComfyConnector.replace_key_value(value, target_key, new_value, class_type_list, exclude) + # Recurse sideways (into lists) + if isinstance(value, list): + for item in value: + if isinstance(item, dict): + ComfyConnector.replace_key_value(item, target_key, new_value, class_type_list, exclude) + + +###### test code below: + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + "steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "1_bloody_mary_v1.safetensors" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "9": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) + + + +# Instantiate APISingleton +api_singleton = ComfyConnector() + +# Test upload_image method +upload_result = api_singleton.upload_image('low_res_cat.png', folder_type='input') +print("Upload result:", upload_result) + +# Test queue_prompt method +prompt_result = api_singleton.queue_prompt(prompt) +print("Prompt result:", prompt_result) + +# Get the prompt ID +prompt_id = prompt_result['prompt_id'] + +# Test get_history method with the prompt ID +history_result = api_singleton.get_history(prompt_id) +print("History result:", history_result) + +# Test generate_images method with the sample prompt +image_file = api_singleton.generate_images(prompt) +image = Image.open(image_file) +image_path = "image.png" # Modify this path as needed +image.save(image_path) +print(f"Image saved to {image_path}") + +# Assuming that the generated image filename, subfolder, and folder_type are known +filename = "image_00001_.png" +subfolder = "" +folder_type = "output" + +# Test get_image method with the known details +image_data = api_singleton.get_image(filename, subfolder, folder_type) +image_file = io.BytesIO(image_data) +image = Image.open(image_file) +image.show() # Display the retrieved image + +print("All methods tested successfully.") \ No newline at end of file