ComfyUI/scripts/runpod_handler.py

1012 lines
46 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import runpod
from runpod.serverless.utils import rp_upload
import json
import urllib.request
import urllib.parse
import time
import os
import requests
import base64
from io import BytesIO
import websocket
import uuid
import tempfile
import socket
import traceback
def _env_int(name, default):
    """Read an integer setting from the environment.

    Args:
        name (str): Name of the environment variable.
        default (int): Value used when the variable is unset or not a valid integer.

    Returns:
        int: The parsed value, or ``default`` when the variable is missing or malformed.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # A typo in the endpoint configuration should not prevent the worker
        # from starting; report it and carry on with the default.
        print(
            f"worker-comfyui - Ignoring invalid value {raw!r} for {name}; using default {default}"
        )
        return default


# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 50
# Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500

# Websocket reconnection behaviour (can be overridden through environment variables).
# NOTE: more attempts and diagnostics improve debuggability whenever ComfyUI crashes mid-job.
#   - WEBSOCKET_RECONNECT_ATTEMPTS sets how many times we will try to reconnect.
#   - WEBSOCKET_RECONNECT_DELAY_S sets the sleep in seconds between attempts.
#
# If the respective env-vars are not supplied (or are not valid integers) we fall
# back to the defaults 5 and 3.
WEBSOCKET_RECONNECT_ATTEMPTS = _env_int("WEBSOCKET_RECONNECT_ATTEMPTS", 5)
WEBSOCKET_RECONNECT_DELAY_S = _env_int("WEBSOCKET_RECONNECT_DELAY_S", 3)

# Extra verbose websocket trace logs (set WEBSOCKET_TRACE=true to enable)
if os.environ.get("WEBSOCKET_TRACE", "false").lower() == "true":
    # This prints low-level frame information to stdout which is invaluable for diagnosing
    # protocol errors but can be noisy in production, therefore gated behind an env-var.
    websocket.enableTrace(True)

# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"

# Enforce a clean state after each job is done
# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker
REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"
# ---------------------------------------------------------------------------
# Helper: quick reachability probe of ComfyUI HTTP endpoint (port 8188)
# ---------------------------------------------------------------------------
def _comfy_server_status():
    """Probe the ComfyUI HTTP root and summarise whether it answered.

    Returns:
        dict: ``{"reachable": bool, "status_code": int}`` when a response was
        received (reachable only for HTTP 200), or
        ``{"reachable": False, "error": str}`` when the request raised.
    """
    try:
        status = requests.get(f"http://{COMFY_HOST}/", timeout=5).status_code
    except Exception as exc:
        return {"reachable": False, "error": str(exc)}
    return {"reachable": status == 200, "status_code": status}
def _attempt_websocket_reconnect(ws_url, max_attempts, delay_s, initial_error):
    """Open a fresh websocket connection after the previous one was closed.

    Before every attempt the ComfyUI HTTP endpoint is probed; if it does not
    answer with HTTP 200 the function gives up immediately instead of retrying.

    Args:
        ws_url (str): The WebSocket URL (including client_id).
        max_attempts (int): Maximum number of reconnection attempts.
        delay_s (int): Delay in seconds between attempts.
        initial_error (Exception): The error that triggered the reconnect attempt.

    Returns:
        websocket.WebSocket: The newly connected WebSocket object.

    Raises:
        websocket.WebSocketConnectionClosedException: If the HTTP probe fails or
            no attempt succeeds.
    """
    print(
        f"worker-comfyui - Websocket connection closed unexpectedly: {initial_error}. Attempting to reconnect..."
    )
    latest_failure = initial_error

    for attempt_no in range(1, max_attempts + 1):
        # Probe the HTTP port first: it tells a dropped websocket apart from
        # a ComfyUI process that is no longer answering at all.
        probe = _comfy_server_status()
        if not probe["reachable"]:
            reason = probe.get("error", "status " + str(probe.get("status_code")))
            print(
                f"worker-comfyui - ComfyUI HTTP unreachable aborting websocket reconnect: {reason}"
            )
            raise websocket.WebSocketConnectionClosedException(
                "ComfyUI HTTP unreachable during websocket reconnect"
            )

        http_status = probe.get("status_code")
        print(
            f"worker-comfyui - Reconnect attempt {attempt_no}/{max_attempts}... (ComfyUI HTTP reachable, status {http_status})"
        )
        try:
            # A brand-new socket object is created for every attempt.
            fresh_ws = websocket.WebSocket()
            fresh_ws.connect(ws_url, timeout=10)
        except (
            websocket.WebSocketException,
            ConnectionRefusedError,
            socket.timeout,
            OSError,
        ) as reconn_err:
            latest_failure = reconn_err
            print(
                f"worker-comfyui - Reconnect attempt {attempt_no} failed: {reconn_err}"
            )
            if attempt_no == max_attempts:
                print("worker-comfyui - Max reconnection attempts reached.")
            else:
                print(
                    f"worker-comfyui - Waiting {delay_s} seconds before next attempt..."
                )
                time.sleep(delay_s)
        else:
            print("worker-comfyui - Websocket reconnected successfully.")
            return fresh_ws

    # Every attempt failed (or max_attempts was zero).
    print("worker-comfyui - Failed to reconnect websocket after connection closed.")
    raise websocket.WebSocketConnectionClosedException(
        f"Connection closed and failed to reconnect. Last error: {latest_failure}"
    )
def validate_input(job_input):
    """Validate the input for the handler function.

    Args:
        job_input (dict | str | None): The job input, either already decoded
            or as a JSON string.

    Returns:
        tuple: ``(validated_data, error_message)``. On success
        ``validated_data`` is ``{"workflow": ..., "images": ...}`` and
        ``error_message`` is ``None``; on failure ``validated_data`` is ``None``
        and ``error_message`` describes the problem.
    """
    if job_input is None:
        return None, "Please provide input"

    # Accept a JSON string as well as an already-decoded object.
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, "Invalid JSON format in input"

    # Valid JSON may still be a list, number, string or null; those have no
    # ``.get`` and would otherwise crash below instead of yielding an error.
    if not isinstance(job_input, dict):
        return None, "Input must be a JSON object"

    workflow = job_input.get("workflow")
    if workflow is None:
        return None, "Missing 'workflow' parameter"

    # 'images' is optional, but when present must be a list of objects that
    # each carry 'name' and 'image'. The isinstance check per entry prevents
    # TypeError on scalars and accidental substring matches on strings.
    images = job_input.get("images")
    if images is not None:
        well_formed = isinstance(images, list) and all(
            isinstance(image, dict) and "name" in image and "image" in image
            for image in images
        )
        if not well_formed:
            return (
                None,
                "'images' must be a list of objects with 'name' and 'image' keys",
            )

    return {"workflow": workflow, "images": images}, None
def check_server(url, retries=500, delay=50):
    """Check if a server is reachable via HTTP GET request.

    Args:
        url (str): The URL to check.
        retries (int, optional): The number of times to attempt connecting to
            the server. Default is 500.
        delay (int, optional): The time in milliseconds to wait between
            retries. Default is 50.

    Returns:
        bool: True if the server answered with HTTP 200 within the given
        number of retries, otherwise False.
    """
    print(f"worker-comfyui - Checking API server at {url}...")
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=5)
            # If the response status code is 200, the server is up and running
            if response.status_code == 200:
                print("worker-comfyui - API is reachable")
                return True
        except requests.RequestException:
            # Covers timeouts too (requests.Timeout is a RequestException);
            # the server is simply not ready yet, so try again.
            pass
        # Wait before the next attempt, but not after the final one.
        if attempt < retries - 1:
            time.sleep(delay / 1000)
    print(
        f"worker-comfyui - Failed to connect to server at {url} after {retries} attempts."
    )
    return False
def upload_images(images):
    """Send base64 encoded images to the ComfyUI ``/upload/image`` endpoint.

    Args:
        images (list): Dictionaries holding the image 'name' and the 'image'
            payload as a base64 string (optionally with a data-URI prefix).

    Returns:
        dict: ``status`` ("success" or "error"), a ``message`` and ``details``
        (per-image success messages, or the collected error messages).
    """
    if not images:
        return {"status": "success", "message": "No images to upload", "details": []}

    uploaded = []
    failures = []

    def record_failure(message):
        # Log the problem and remember it for the final report.
        print(f"worker-comfyui - {message}")
        failures.append(message)

    print(f"worker-comfyui - Uploading {len(images)} image(s)...")
    for image in images:
        try:
            name = image["name"]
            raw = image["image"]
            # Drop everything up to the first comma (data-URI prefix), if any.
            base64_data = raw.split(",", 1)[1] if "," in raw else raw
            blob = base64.b64decode(base64_data)

            form = {
                "image": (name, BytesIO(blob), "image/png"),
                "overwrite": (None, "true"),
            }
            response = requests.post(
                f"http://{COMFY_HOST}/upload/image", files=form, timeout=30
            )
            response.raise_for_status()

            uploaded.append(f"Successfully uploaded {name}")
            print(f"worker-comfyui - Successfully uploaded {name}")
        except base64.binascii.Error as e:
            record_failure(
                f"Error decoding base64 for {image.get('name', 'unknown')}: {e}"
            )
        except requests.Timeout:
            record_failure(f"Timeout uploading {image.get('name', 'unknown')}")
        except requests.RequestException as e:
            record_failure(f"Error uploading {image.get('name', 'unknown')}: {e}")
        except Exception as e:
            record_failure(
                f"Unexpected error uploading {image.get('name', 'unknown')}: {e}"
            )

    if failures:
        print("worker-comfyui - image(s) upload finished with errors")
        return {
            "status": "error",
            "message": "Some images failed to upload",
            "details": failures,
        }

    print("worker-comfyui - image(s) upload complete")
    return {
        "status": "success",
        "message": "All images uploaded successfully",
        "details": uploaded,
    }
def get_available_models():
    """Ask ComfyUI's ``/object_info`` endpoint which model files selected loaders offer.

    Returns:
        dict: Maps ``"checkpoints"``, ``"clip_name1"``, ``"clip_name2"`` and
        ``"unet_name"`` to the list of choices reported for the corresponding
        loader input. Keys are only present when the loader node and input
        exist. An empty dict is returned when anything goes wrong.
    """
    # (node class, ((required input name, key in the result), ...))
    lookups = (
        ("CheckpointLoaderSimple", (("ckpt_name", "checkpoints"),)),
        (
            "DualCLIPLoader",
            (("clip_name1", "clip_name1"), ("clip_name2", "clip_name2")),
        ),
        ("UnetLoaderGGUF", (("unet_name", "unet_name"),)),
    )
    try:
        response = requests.get(f"http://{COMFY_HOST}/object_info", timeout=10)
        response.raise_for_status()
        object_info = response.json()

        found = {}
        for node_class, wanted_inputs in lookups:
            if node_class not in object_info:
                continue
            node_info = object_info[node_class]
            if "input" not in node_info or "required" not in node_info["input"]:
                continue
            required = node_info["input"]["required"]
            for input_name, result_key in wanted_inputs:
                options = required.get(input_name)
                if options and len(options) > 0:
                    # The first element holds the list of choices; anything
                    # that is not a list is reported as "no choices".
                    first = options[0]
                    found[result_key] = first if isinstance(first, list) else []
        return found
    except Exception as e:
        print(f"worker-comfyui - Warning: Could not fetch available models: {e}")
        return {}
# Inputs for which get_available_models() can list alternatives.
# Maps the node input name to (key in get_available_models(), label used in messages).
_KNOWN_MODEL_INPUTS = {
    "clip_name1": ("clip_name1", "clip_name1"),
    "clip_name2": ("clip_name2", "clip_name2"),
    "unet_name": ("unet_name", "unet_name"),
    "ckpt_name": ("checkpoints", "checkpoint"),
}


def _summarize_node_errors(node_errors):
    """Turn the ``node_errors`` mapping of a 400 response into one line per node."""
    details = []
    for node_id, node_error in node_errors.items():
        if not isinstance(node_error, dict):
            details.append(f"Node {node_id}: {node_error}")
            continue

        node_class = node_error.get("class_type", "Unknown")
        if "errors" not in node_error:
            # Fallback: format all keys except the purely informational ones.
            for error_type, error_msg in node_error.items():
                if error_type not in ("class_type", "dependent_outputs"):
                    details.append(f"Node {node_id} ({error_type}): {error_msg}")
            continue

        summary = []
        for err in node_error["errors"]:
            if isinstance(err, dict):
                # Prefer the detailed text, fall back to the short message.
                text = err.get("details", "") or err.get("message", "")
                if text:
                    summary.append(str(text))
        if summary:
            details.append(f"Node {node_id} ({node_class}): {'; '.join(summary)}")
        else:
            details.append(f"Node {node_id} ({node_class}): Validation error")
    return details


def _suggest_models(node_errors, available_models):
    """List valid choices for every ``value_not_in_list`` error in ``node_errors``."""
    suggestions = []
    seen_generic = set()
    for node_id, node_error in node_errors.items():
        if not (isinstance(node_error, dict) and "errors" in node_error):
            continue
        node_class = node_error.get("class_type", "Unknown")
        prefix = f"Node {node_id} ({node_class}): "

        for error_item in node_error["errors"]:
            if not (
                isinstance(error_item, dict)
                and error_item.get("type") == "value_not_in_list"
            ):
                continue
            extra_info = error_item.get("extra_info") or {}
            input_name = extra_info.get("input_name")
            received_value = extra_info.get("received_value")
            input_config = extra_info.get("input_config", [])

            # input_config is read as [[list_of_available], {...}]; anything
            # else is treated as "no choices offered".
            offered = []
            if (
                isinstance(input_config, list)
                and input_config
                and isinstance(input_config[0], list)
            ):
                offered = input_config[0]

            known = (
                _KNOWN_MODEL_INPUTS.get(input_name)
                if isinstance(input_name, str)
                else None
            )
            if known is not None:
                models_key, label = known
                # Prefer what ComfyUI reports right now; fall back to the list
                # embedded in the error itself.
                choices = available_models.get(models_key) or offered
                if choices:
                    suggestions.append(
                        f"{prefix}Available {label} models: {', '.join(map(str, choices))}"
                    )
                else:
                    suggestions.append(
                        f"{prefix}Requested '{received_value}' for {input_name} is not available (no models found)."
                    )
            elif received_value:
                # Generic fallback for any other input; report each
                # node/input combination only once.
                suggestion_key = f"{node_id}_{input_name}"
                if suggestion_key in seen_generic:
                    continue
                seen_generic.add(suggestion_key)
                if offered:
                    suggestions.append(
                        f"{prefix}Available {input_name} models: {', '.join(map(str, offered))}"
                    )
                else:
                    suggestions.append(
                        f"{prefix}Requested '{received_value}' for '{input_name}' is not available."
                    )
    return suggestions


def _raise_validation_error(response):
    """Build the most helpful message possible from a 400 response and raise ValueError."""
    unparseable = (
        f"ComfyUI validation failed (could not parse error response): {response.text}"
    )
    try:
        error_data = response.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError subclass.
        raise ValueError(unparseable) from exc
    print(f"worker-comfyui - Parsed error data: {error_data}")
    if not isinstance(error_data, dict):
        # Valid JSON, but not the object shape the code below relies on.
        raise ValueError(unparseable)

    error_message = "Workflow validation failed"
    if "error" in error_data:
        error_info = error_data["error"]
        if isinstance(error_info, dict):
            error_message = error_info.get("message", error_message)
            if error_info.get("type") == "prompt_outputs_failed_validation":
                error_message = "Workflow validation failed"
        else:
            error_message = str(error_info)

    # NOTE(review): node_errors is expected to be a mapping of node id to
    # error info, but some ComfyUI versions appear to send a list when there
    # are no per-node errors; confirm against the deployed version. Anything
    # that is not a dict is treated as "no node errors".
    node_errors = error_data.get("node_errors")
    if not isinstance(node_errors, dict):
        node_errors = {}
    error_details = _summarize_node_errors(node_errors)

    if error_data.get("type") == "prompt_outputs_failed_validation":
        # This shape carries no per-node details, so give a generic hint plus
        # the checkpoints that are actually installed.
        error_message = error_data.get("message", "Workflow validation failed")
        available_models = get_available_models()
        error_message += (
            "\n\nThis usually means a required model or parameter is not available."
        )
        if available_models.get("checkpoints"):
            error_message += f"\nAvailable checkpoint models: {', '.join(available_models['checkpoints'])}"
        else:
            error_message += "\nNo checkpoint models appear to be available. Please check your model installation."
        raise ValueError(error_message)

    if not error_details:
        # Fallback to the raw response if we can't parse specific errors
        raise ValueError(f"{error_message}. Raw response: {response.text}")

    detailed_message = f"{error_message}:\n" + "\n".join(error_details)
    available_models = get_available_models()
    model_suggestions = _suggest_models(node_errors, available_models)
    if model_suggestions:
        detailed_message += "\n\nAvailable models:\n" + "\n".join(model_suggestions)
    elif any("not in list" in detail for detail in error_details):
        # We could not work out specific alternatives; at least say whether
        # any models were found at all.
        if available_models:
            detailed_message += "\n\nNote: Some models may be available. Check your workflow configuration."
        else:
            detailed_message += "\n\nNo models appear to be available. Please check your model installation."
    raise ValueError(detailed_message)


def queue_workflow(workflow, client_id):
    """Queue a workflow to be processed by ComfyUI.

    Args:
        workflow (dict): A dictionary containing the workflow to be processed.
        client_id (str): The client ID for the websocket connection.

    Returns:
        dict: The JSON response from ComfyUI's ``/prompt`` endpoint.

    Raises:
        ValueError: If ComfyUI answers with HTTP 400; the message contains
            whatever validation details could be extracted from the response.
        requests.RequestException: For connection problems and other HTTP
            error statuses.
    """
    # Include client_id in the prompt payload
    payload = {"prompt": workflow, "client_id": client_id}
    data = json.dumps(payload).encode("utf-8")
    headers = {"Content-Type": "application/json"}
    response = requests.post(
        f"http://{COMFY_HOST}/prompt", data=data, headers=headers, timeout=30
    )

    # HTTP 400 is handled separately so the caller gets a readable explanation.
    if response.status_code == 400:
        print(f"worker-comfyui - ComfyUI returned 400. Response body: {response.text}")
        _raise_validation_error(response)

    # For other HTTP errors, raise them normally
    response.raise_for_status()
    return response.json()
def get_history(prompt_id):
    """Fetch the ComfyUI history entry for one prompt.

    Args:
        prompt_id (str): The ID of the prompt whose history is to be retrieved.

    Returns:
        dict: The decoded JSON body of ``/history/<prompt_id>``.

    Raises:
        requests.RequestException: On connection problems, timeouts or an
            HTTP error status.
    """
    history_url = f"http://{COMFY_HOST}/history/{prompt_id}"
    reply = requests.get(history_url, timeout=30)
    reply.raise_for_status()
    return reply.json()
def get_file_data(filename, subfolder, file_type):
    """Download one file through the ComfyUI ``/view`` endpoint.

    Args:
        filename (str): The filename of the file.
        subfolder (str): The subfolder where the file is stored.
        file_type (str): The type of the file (e.g., 'output').

    Returns:
        bytes: The raw file data, or None if the download failed for any reason.
    """
    print(
        f"worker-comfyui - Fetching file data: type={file_type}, subfolder={subfolder}, filename={filename}"
    )
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": file_type}
    )
    try:
        reply = requests.get(f"http://{COMFY_HOST}/view?{query}", timeout=60)
        reply.raise_for_status()
        print(f"worker-comfyui - Successfully fetched file data for {filename}")
        return reply.content
    except requests.Timeout:
        failure = f"worker-comfyui - Timeout fetching file data for {filename}"
    except requests.RequestException as e:
        failure = f"worker-comfyui - Error fetching file data for {filename}: {e}"
    except Exception as e:
        failure = (
            f"worker-comfyui - Unexpected error fetching file data for {filename}: {e}"
        )
    # Every failure path ends here: log it and signal "no data" to the caller.
    print(failure)
    return None
def _collect_node_files(
    job_id,
    node_id,
    file_infos,
    kind,
    default_ext,
    uploader_name,
    skip_temp,
    errors,
):
    """Fetch the files listed under one output key of a node and package them.

    Args:
        job_id (str): Job ID passed to the bucket upload function.
        node_id (str): ID of the node that produced the files (used in messages).
        file_infos (list): Entries with 'filename', 'subfolder' and 'type'.
        kind (str): "image" or "audio"; only used in log and error messages.
        default_ext (str): Suffix for the temporary file when the filename has none.
        uploader_name (str): Name of the ``rp_upload`` function used for bucket uploads.
        skip_temp (bool): Whether entries whose type is 'temp' are skipped.
        errors (list): Receives a message for every file that could not be exported.

    Returns:
        list: One ``{"filename", "type", "data"}`` dict per exported file, where
        type is "s3_url" when BUCKET_ENDPOINT_URL is set and "base64" otherwise.
    """
    collected = []
    for info in file_infos:
        filename = info.get("filename")
        subfolder = info.get("subfolder", "")
        file_type = info.get("type")

        if skip_temp and file_type == "temp":
            print(
                f"worker-comfyui - Skipping {kind} {filename} because type is 'temp'"
            )
            continue
        if not filename:
            warn_msg = (
                f"Skipping {kind} in node {node_id} due to missing filename: {info}"
            )
            print(f"worker-comfyui - {warn_msg}")
            errors.append(warn_msg)
            continue

        file_bytes = get_file_data(filename, subfolder, file_type)
        if not file_bytes:
            errors.append(
                f"Failed to fetch {kind} data for {filename} from /view endpoint."
            )
            continue

        if os.environ.get("BUCKET_ENDPOINT_URL"):
            file_extension = os.path.splitext(filename)[1] or default_ext
            temp_file_path = None
            try:
                with tempfile.NamedTemporaryFile(
                    suffix=file_extension, delete=False
                ) as temp_file:
                    # Remember the path before writing so that the file is
                    # still removed if the write itself fails.
                    temp_file_path = temp_file.name
                    temp_file.write(file_bytes)
                print(
                    f"worker-comfyui - Wrote {kind} bytes to temporary file: {temp_file_path}"
                )
                print(f"worker-comfyui - Uploading {filename} to S3...")
                # Resolved inside the try block so that a missing rp_upload
                # function is reported as an upload error for this file only.
                upload = getattr(rp_upload, uploader_name)
                s3_url = upload(job_id, temp_file_path)
                print(f"worker-comfyui - Uploaded {filename} to S3: {s3_url}")
                collected.append(
                    {"filename": filename, "type": "s3_url", "data": s3_url}
                )
            except Exception as e:
                error_msg = f"Error uploading {filename} to S3: {e}"
                print(f"worker-comfyui - {error_msg}")
                errors.append(error_msg)
            finally:
                # Clean up on success and on failure; a failed removal is only
                # logged so it cannot discard an upload that succeeded.
                if temp_file_path and os.path.exists(temp_file_path):
                    try:
                        os.remove(temp_file_path)
                    except OSError as rm_err:
                        print(
                            f"worker-comfyui - Error removing temp file {temp_file_path}: {rm_err}"
                        )
        else:
            # Return as base64 string
            try:
                encoded = base64.b64encode(file_bytes).decode("utf-8")
                collected.append(
                    {"filename": filename, "type": "base64", "data": encoded}
                )
                print(f"worker-comfyui - Encoded {filename} as base64")
            except Exception as e:
                error_msg = f"Error encoding {filename} to base64: {e}"
                print(f"worker-comfyui - {error_msg}")
                errors.append(error_msg)
    return collected


def handler(job):
    """Handle a job using ComfyUI via websockets for status and file retrieval.

    The workflow is validated, optional input images are uploaded, the
    workflow is queued, completion is awaited over the websocket and the
    produced images/audio are returned as base64 or bucket URLs.

    Args:
        job (dict): A dictionary containing job details and input parameters
            (keys 'id' and 'input').

    Returns:
        dict: Either ``{"error": ..., ["details": ...]}`` or a result with any
        of ``images``, ``audio``, ``errors`` and ``status``.
    """
    job_input = job.get("input")
    job_id = job["id"]

    # Make sure that the input is valid
    validated_data, error_message = validate_input(job_input)
    if error_message:
        return {"error": error_message}

    workflow = validated_data["workflow"]
    input_images = validated_data.get("images")

    # Make sure that the ComfyUI HTTP API is available before proceeding
    if not check_server(
        f"http://{COMFY_HOST}/",
        COMFY_API_AVAILABLE_MAX_RETRIES,
        COMFY_API_AVAILABLE_INTERVAL_MS,
    ):
        return {
            "error": f"ComfyUI server ({COMFY_HOST}) not reachable after multiple retries."
        }

    # Upload input images if they exist
    if input_images:
        upload_result = upload_images(input_images)
        if upload_result["status"] == "error":
            return {
                "error": "Failed to upload one or more input images",
                "details": upload_result["details"],
            }

    ws = None
    client_id = str(uuid.uuid4())
    prompt_id = None
    output_images = []
    output_audio = []
    errors = []

    try:
        # Establish WebSocket connection
        ws_url = f"ws://{COMFY_HOST}/ws?clientId={client_id}"
        print(f"worker-comfyui - Connecting to websocket: {ws_url}")
        ws = websocket.WebSocket()
        ws.connect(ws_url, timeout=10)
        print("worker-comfyui - Websocket connected")

        # Queue the workflow; every failure is surfaced as ValueError so the
        # outer handler returns its message unchanged.
        try:
            queued_workflow = queue_workflow(workflow, client_id)
            prompt_id = queued_workflow.get("prompt_id")
            if not prompt_id:
                raise ValueError(
                    f"Missing 'prompt_id' in queue response: {queued_workflow}"
                )
            print(f"worker-comfyui - Queued workflow with ID: {prompt_id}")
        except requests.RequestException as e:
            print(f"worker-comfyui - Error queuing workflow: {e}")
            raise ValueError(f"Error queuing workflow: {e}") from e
        except ValueError as e:
            # Validation errors from queue_workflow already carry a readable message.
            print(f"worker-comfyui - Unexpected error queuing workflow: {e}")
            raise
        except Exception as e:
            print(f"worker-comfyui - Unexpected error queuing workflow: {e}")
            raise ValueError(f"Unexpected error queuing workflow: {e}") from e

        # Wait for execution completion via WebSocket
        print(f"worker-comfyui - Waiting for workflow execution ({prompt_id})...")
        execution_done = False
        while True:
            try:
                out = ws.recv()
                if not isinstance(out, str):
                    # Only text frames are inspected.
                    continue
                message = json.loads(out)
                message_type = message.get("type")
                if message_type == "status":
                    status_data = message.get("data", {}).get("status", {})
                    print(
                        f"worker-comfyui - Status update: {status_data.get('exec_info', {}).get('queue_remaining', 'N/A')} items remaining in queue"
                    )
                elif message_type == "executing":
                    data = message.get("data", {})
                    # An "executing" message without a node for our prompt is
                    # treated as the end of execution.
                    if data.get("node") is None and data.get("prompt_id") == prompt_id:
                        print(
                            f"worker-comfyui - Execution finished for prompt {prompt_id}"
                        )
                        execution_done = True
                        break
                elif message_type == "execution_error":
                    data = message.get("data", {})
                    if data.get("prompt_id") == prompt_id:
                        error_details = f"Node Type: {data.get('node_type')}, Node ID: {data.get('node_id')}, Message: {data.get('exception_message')}"
                        print(
                            f"worker-comfyui - Execution error received: {error_details}"
                        )
                        errors.append(f"Workflow execution error: {error_details}")
                        break
            except websocket.WebSocketTimeoutException:
                print("worker-comfyui - Websocket receive timed out. Still waiting...")
                continue
            except websocket.WebSocketConnectionClosedException as closed_err:
                # A failed reconnect raises WebSocketConnectionClosedException,
                # which propagates to the outer WebSocketException handler.
                ws = _attempt_websocket_reconnect(
                    ws_url,
                    WEBSOCKET_RECONNECT_ATTEMPTS,
                    WEBSOCKET_RECONNECT_DELAY_S,
                    closed_err,
                )
                print(
                    "worker-comfyui - Resuming message listening after successful reconnect."
                )
                continue
            except json.JSONDecodeError:
                print("worker-comfyui - Received invalid JSON message via websocket.")

        if not execution_done and not errors:
            raise ValueError(
                "Workflow monitoring loop exited without confirmation of completion or error."
            )

        # Fetch history even if there were execution errors, some outputs might exist
        print(f"worker-comfyui - Fetching history for prompt {prompt_id}...")
        history = get_history(prompt_id)

        if prompt_id not in history:
            error_msg = f"Prompt ID {prompt_id} not found in history after execution."
            print(f"worker-comfyui - {error_msg}")
            if not errors:
                return {"error": error_msg}
            errors.append(error_msg)
            return {
                "error": "Job processing failed, prompt ID not found in history.",
                "details": errors,
            }

        prompt_history = history.get(prompt_id, {})
        outputs = prompt_history.get("outputs", {})

        if not outputs:
            warning_msg = f"No outputs found in history for prompt {prompt_id}."
            print(f"worker-comfyui - {warning_msg}")
            if not errors:
                errors.append(warning_msg)

        print(f"worker-comfyui - Processing {len(outputs)} output nodes...")
        for node_id, node_output in outputs.items():
            if "images" in node_output:
                print(
                    f"worker-comfyui - Node {node_id} contains {len(node_output['images'])} image(s)"
                )
                output_images.extend(
                    _collect_node_files(
                        job_id,
                        node_id,
                        node_output["images"],
                        "image",
                        ".png",
                        "upload_image",
                        True,
                        errors,
                    )
                )

            if "audio" in node_output:
                print(
                    f"worker-comfyui - Node {node_id} contains {len(node_output['audio'])} audio file(s)"
                )
                # NOTE(review): confirm that the installed runpod SDK exposes
                # rp_upload.upload_file; if it does not, every audio upload
                # is reported as an S3 upload error.
                output_audio.extend(
                    _collect_node_files(
                        job_id,
                        node_id,
                        node_output["audio"],
                        "audio",
                        ".wav",
                        "upload_file",
                        False,
                        errors,
                    )
                )

            # Check for other output types
            other_keys = [k for k in node_output.keys() if k not in ["images", "audio"]]
            if other_keys:
                warn_msg = (
                    f"Node {node_id} produced unhandled output keys: {other_keys}."
                )
                print(f"worker-comfyui - WARNING: {warn_msg}")
                print(
                    "worker-comfyui - --> If this output is useful, please consider opening an issue on GitHub to discuss adding support."
                )

    except websocket.WebSocketException as e:
        print(f"worker-comfyui - WebSocket Error: {e}")
        print(traceback.format_exc())
        return {"error": f"WebSocket communication error: {e}"}
    except requests.RequestException as e:
        print(f"worker-comfyui - HTTP Request Error: {e}")
        print(traceback.format_exc())
        return {"error": f"HTTP communication error with ComfyUI: {e}"}
    except ValueError as e:
        print(f"worker-comfyui - Value Error: {e}")
        print(traceback.format_exc())
        return {"error": str(e)}
    except Exception as e:
        print(f"worker-comfyui - Unexpected Handler Error: {e}")
        print(traceback.format_exc())
        return {"error": f"An unexpected error occurred: {e}"}
    finally:
        # Runs for the early returns above as well.
        if ws and ws.connected:
            print("worker-comfyui - Closing websocket connection.")
            ws.close()

    final_result = {}
    if output_images:
        final_result["images"] = output_images
    if output_audio:
        final_result["audio"] = output_audio
    if errors:
        final_result["errors"] = errors
        print(f"worker-comfyui - Job completed with errors/warnings: {errors}")

    if not output_images and not output_audio and errors:
        print("worker-comfyui - Job failed with no output files.")
        return {
            "error": "Job processing failed",
            "details": errors,
        }
    elif not output_images and not output_audio and not errors:
        print(
            "worker-comfyui - Job completed successfully, but the workflow produced no images or audio."
        )
        final_result["status"] = "success_no_output"

    print(
        f"worker-comfyui - Job completed. Returning {len(output_images)} image(s) and {len(output_audio)} audio file(s)."
    )
    return final_result
if __name__ == "__main__":
    # Script entry point: register ``handler`` with the RunPod serverless runtime.
    print("worker-comfyui - Starting handler...")
    serverless_config = {"handler": handler}
    runpod.serverless.start(serverless_config)