# ComfyUI/apiconnector/apiconnector.py
# 2023-08-14 10:58:33 -04:00
#
# 284 lines
# 10 KiB
# Python

import uuid
import json
import urllib.request
import urllib.parse
from PIL import Image
from websocket import WebSocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import io
import requests
import time
import os
import subprocess
# Runtime configuration, read once from the environment when the module is imported.
APP_NAME = os.getenv('APP_NAME')  # Name of the application
API_COMMAND_LINE = os.getenv('API_COMMAND_LINE')  # Command line to start the API server
API_URL = os.getenv('API_URL')  # Address of the API server; ComfyConnector prefixes it with "http://" / "ws://"
# Unique identifier for this instance of the worker. A generic prefix is used
# when APP_NAME is unset so that importing the module does not fail on
# `None + '-'`.
INSTANCE_IDENTIFIER = f"{APP_NAME if APP_NAME is not None else 'comfy-connector'}-{uuid.uuid4()}"
TEST_PAYLOAD = os.getenv('TEST_PAYLOAD')  # Test prompt used by ComfyConnector.is_api_running
class ComfyConnector:
    """Singleton client that launches the API server process and talks to it.

    The server is reached over HTTP (``/prompt``, ``/history``, ``/view``,
    ``/upload/image``) and over a websocket (``/ws``) that reports execution
    progress. ``API_URL`` is used as a bare address: every request prefixes it
    with ``http://`` or ``ws://``.
    """

    _instance = None  # The single shared ComfyConnector instance
    _process = None  # subprocess.Popen handle of the server started by start_api()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(ComfyConnector, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """Configure the connector and start the API server (first call only).

        ``__init__`` runs on every ``ComfyConnector()`` call even though
        ``__new__`` returns the same object, so the ``initialized`` flag keeps
        the setup from being repeated.
        """
        if not hasattr(self, 'initialized'):
            self.server_address = API_URL
            self.client_id = INSTANCE_IDENTIFIER
            self.ws = WebSocket()
            # The websocket is connected by is_api_running() once the server
            # answers; connecting here, before the server process has been
            # launched, would fail.
            self.start_api()
            self.initialized = True

    def start_api(self):
        """Launch the API server process if needed and block until it is ready.

        Raises:
            RuntimeError: if ``API_COMMAND_LINE`` is not configured.
        """
        # Only launch when no process was started or it has terminated.
        if self._process is None or self._process.poll() is not None:
            if not API_COMMAND_LINE:
                raise RuntimeError(
                    "API_COMMAND_LINE environment variable is not set; "
                    "cannot start the API server"
                )
            self._process = subprocess.Popen(API_COMMAND_LINE.split())
            print("API process started with PID:", self._process.pid)
            while not self.is_api_running():  # Block execution until the API server is running
                time.sleep(0.5)  # Wait for 0.5 seconds before checking again

    def is_api_running(self):
        """Return True when the API server answers and can execute a prompt.

        The check has two stages: the web server must return HTTP 200, and,
        when ``TEST_PAYLOAD`` is configured, a test generation must yield at
        least one image. ``TEST_PAYLOAD`` comes from the environment, so a
        string value is decoded as JSON before use. Without a test payload the
        HTTP check plus a successful websocket connection is taken as
        sufficient.

        Any exception is reported and treated as "not running" so the caller
        can keep polling.
        """
        try:
            # Use the same "http://<address>" form as every other request;
            # the bare address has no scheme and requests would reject it.
            response = requests.get(f"http://{self.server_address}", timeout=5)
            if response.status_code != 200:
                return False
            # generate_images() reads progress from the websocket, so it must
            # be connected before the test generation runs.
            if not self.ws.connected:
                self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
            test_payload = TEST_PAYLOAD
            if test_payload is None:
                return True
            if isinstance(test_payload, str):
                test_payload = json.loads(test_payload)
            # This ensures that the API server is actually running and not just the web server.
            test_image = self.generate_images(test_payload)
            return bool(test_image)
        except Exception as e:
            print("API not running:", e)
            return False

    def kill_api(self):
        """Kill the API server process if it is still running."""
        if self._process is not None and self._process.poll() is None:
            self._process.kill()
            self._process.wait()  # Reap the child so it does not linger as a zombie
            self._process = None
            print("API process killed")

    def get_history(self, prompt_id):
        """Return the decoded JSON body of ``/history/<prompt_id>``."""
        with urllib.request.urlopen(f"http://{self.server_address}/history/{prompt_id}") as response:
            return json.loads(response.read())

    def get_image(self, filename, subfolder, folder_type):
        """Return the raw bytes served by ``/view`` for the given image."""
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen(f"http://{self.server_address}/view?{url_values}") as response:
            return response.read()

    def queue_prompt(self, prompt):
        """POST ``prompt`` to ``/prompt`` and return the decoded JSON response."""
        p = {"prompt": prompt, "client_id": self.client_id}
        data = json.dumps(p).encode('utf-8')
        print("Sending data:", data)  # Print the data for debugging
        headers = {'Content-Type': 'application/json'}  # Set Content-Type header
        req = urllib.request.Request(f"http://{self.server_address}/prompt", data=data, headers=headers)
        # Close the HTTP response deterministically instead of leaking it.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())

    def generate_images(self, prompt):
        """Queue ``prompt``, wait for it to finish, and return its images.

        This is the main method of the class.

        Args:
            prompt: prompt dictionary; it must contain a node whose
                ``class_type`` is ``"SaveImage"``.

        Returns:
            A list of PIL images, one per entry in the ``images`` list of the
            SaveImage node's output.

        Raises:
            ValueError: if the prompt has no SaveImage node.
        """
        # Locate the output node before queueing so that a prompt whose result
        # could never be retrieved is not executed at all.
        output_key = self._find_output_key(prompt)
        if output_key is None:
            raise ValueError("Prompt does not contain a SaveImage node; cannot locate output images")
        prompt_id = self.queue_prompt(prompt)['prompt_id']
        while True:
            out = self.ws.recv()
            if not isinstance(out, str):
                continue  # Only text frames are parsed as JSON status messages
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                # An "executing" message with no node for our prompt id is
                # treated as the signal that execution has finished.
                if data['node'] is None and data['prompt_id'] == prompt_id:
                    break
        history = self.get_history(prompt_id)[prompt_id]
        # Plain dictionary lookup replaces the former eval() of a constructed
        # expression. History keys come from JSON and are therefore strings.
        image_infos = history['outputs'][str(output_key)]['images']  # All images of the node
        images = []
        for img_info in image_infos:
            image_data = self.get_image(img_info['filename'], img_info['subfolder'], img_info['type'])
            images.append(Image.open(io.BytesIO(image_data)))
        return images

    def upload_image(self, filepath, subfolder=None, folder_type=None, overwrite=False):
        """Upload an image file to ``/upload/image`` and return the JSON reply.

        Intended for images used in img2img or controlnet workflows.

        Args:
            filepath: path of the local file to upload.
            subfolder: optional value sent as the ``subfolder`` form field.
            folder_type: optional value sent as the ``type`` form field.
            overwrite: sent as ``"true"`` / ``"false"`` in the ``overwrite`` field.
        """
        url = f"http://{self.server_address}/upload/image"
        data = {
            'overwrite': str(overwrite).lower()
        }
        if subfolder:
            data['subfolder'] = subfolder
        if folder_type:
            data['type'] = folder_type
        # The context manager closes the file even when the request fails.
        with open(filepath, 'rb') as image_file:
            response = requests.post(url, files={'image': image_file}, data=data)
        return response.json()

    @staticmethod
    def _find_output_key(json_object):
        """Return the key of the first dict with class_type "SaveImage", or None.

        The search is depth-first through nested dictionaries, in insertion
        order.
        """
        for key, value in json_object.items():
            if isinstance(value, dict):
                if value.get("class_type") == "SaveImage":
                    return key
                result = ComfyConnector._find_output_key(value)
                if result is not None:
                    return result
        return None

    @staticmethod
    def find_output_node(json_object):
        """Return the SaveImage node key formatted as ``"['<key>']"``, or None.

        The bracketed string form is kept for callers that already depend on it.
        """
        key = ComfyConnector._find_output_key(json_object)
        return None if key is None else f"['{key}']"

    @staticmethod
    def replace_key_value(json_object, target_key, new_value, class_type_list=None, exclude=True):
        """Replace ``target_key`` with ``new_value`` in nested dicts, in place.

        Used to edit the payload of a prompt. The replacement is filtered by
        class type. A dictionary without its own ``class_type`` (such as a
        node's ``inputs`` dictionary) uses the class type of the nearest
        enclosing dictionary that has one, so the filter applies to the node
        that owns the value.

        Args:
            json_object: dictionary to edit; its nested dictionaries are
                searched, including dictionaries held in lists.
            target_key: key whose value is replaced.
            new_value: value to store.
            class_type_list: class types to filter on; ``None`` means no filter.
            exclude: when True, replace everywhere except in the listed class
                types; when False, replace only in the listed class types
                (nothing is replaced if ``class_type_list`` is None).
        """
        ComfyConnector._replace_key_value(
            json_object, target_key, new_value, class_type_list, exclude,
            json_object.get('class_type'),
        )

    @staticmethod
    def _replace_key_value(json_object, target_key, new_value, class_type_list, exclude, inherited_class_type):
        """Recursive worker for replace_key_value carrying the enclosing class type."""
        for value in json_object.values():
            if isinstance(value, dict):
                own_class_type = value.get('class_type')
                class_type = own_class_type if own_class_type is not None else inherited_class_type
                # Determine whether to apply the logic based on exclude and class_type_list
                should_apply_logic = (
                    (exclude and (class_type_list is None or class_type not in class_type_list)) or
                    (not exclude and (class_type_list is not None and class_type in class_type_list))
                )
                # Replace the target key with the new value if conditions are met
                if should_apply_logic and target_key in value:
                    value[target_key] = new_value
                # Recurse vertically (into nested dictionaries)
                ComfyConnector._replace_key_value(
                    value, target_key, new_value, class_type_list, exclude, class_type)
            elif isinstance(value, list):
                # Recurse sideways (into lists)
                for item in value:
                    if isinstance(item, dict):
                        item_class_type = item.get('class_type')
                        ComfyConnector._replace_key_value(
                            item, target_key, new_value, class_type_list, exclude,
                            item_class_type if item_class_type is not None else inherited_class_type,
                        )
###### test code below:
# Sample prompt used by the manual test run at the bottom of the file.
prompt_text = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": 8,
            "denoise": 1,
            "latent_image": ["5", 0],
            "model": ["4", 0],
            "negative": ["7", 0],
            "positive": ["6", 0],
            "sampler_name": "euler",
            "scheduler": "normal",
            "seed": 8566257,
            "steps": 20
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "1_bloody_mary_v1.safetensors"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": 512,
            "width": 512
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": ["4", 1],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": ["4", 1],
            "text": "bad hands"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": ["3", 0],
            "vae": ["4", 2]
        }
    },
    "9": {
        "class_type": "SaveImage",
        "inputs": {
            "filename_prefix": "ComfyUI",
            "images": ["8", 0]
        }
    }
}
"""
prompt = json.loads(prompt_text)

# The calls below need a reachable API server, so they run only when the file
# is executed as a script and not when it is imported as a module.
if __name__ == "__main__":
    # Instantiate the ComfyConnector singleton
    api_singleton = ComfyConnector()
    # Test upload_image method
    upload_result = api_singleton.upload_image('low_res_cat.png', folder_type='input')
    print("Upload result:", upload_result)
    # Test queue_prompt method
    prompt_result = api_singleton.queue_prompt(prompt)
    print("Prompt result:", prompt_result)
    # Get the prompt ID
    prompt_id = prompt_result['prompt_id']
    # Test get_history method with the prompt ID
    history_result = api_singleton.get_history(prompt_id)
    print("History result:", history_result)
    # Test generate_images method with the sample prompt. It returns a list of
    # already-opened PIL images, so each one is saved directly rather than
    # being passed to Image.open() again.
    images = api_singleton.generate_images(prompt)
    for index, image in enumerate(images):
        image_path = f"image_{index}.png"  # Modify this path as needed
        image.save(image_path)
        print(f"Image saved to {image_path}")
    # Assuming that the generated image filename, subfolder, and folder_type are known.
    # NOTE(review): the sample prompt saves with filename_prefix "ComfyUI", so this
    # name probably has to be adjusted to a file that exists on the server - confirm.
    filename = "image_00001_.png"
    subfolder = ""
    folder_type = "output"
    # Test get_image method with the known details
    image_data = api_singleton.get_image(filename, subfolder, folder_type)
    image_file = io.BytesIO(image_data)
    image = Image.open(image_file)
    image.show()  # Display the retrieved image
    print("All methods tested successfully.")