import runpod
from runpod.serverless.utils import rp_upload
import json
import urllib.request
import urllib.parse
import time
import os
import requests
import base64
from io import BytesIO
import websocket
import uuid
import tempfile
import socket
import traceback

# Time to wait between API check attempts in milliseconds
COMFY_API_AVAILABLE_INTERVAL_MS = 50
# Maximum number of API check attempts
COMFY_API_AVAILABLE_MAX_RETRIES = 500
# Websocket reconnection behaviour (can be overridden through environment variables)
# NOTE: more attempts and diagnostics improve debuggability whenever ComfyUI crashes mid-job.
# • WEBSOCKET_RECONNECT_ATTEMPTS sets how many times we will try to reconnect.
# • WEBSOCKET_RECONNECT_DELAY_S sets the sleep in seconds between attempts.
#
# If the respective env-vars are not supplied we fall back to sensible defaults (5 attempts, 3 seconds).
WEBSOCKET_RECONNECT_ATTEMPTS = int(os.environ.get("WEBSOCKET_RECONNECT_ATTEMPTS", 5))
WEBSOCKET_RECONNECT_DELAY_S = int(os.environ.get("WEBSOCKET_RECONNECT_DELAY_S", 3))

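# Illustrative only: these knobs are read from the environment at import time, so a
# deployment could, for example, export WEBSOCKET_RECONNECT_ATTEMPTS=10 and
# WEBSOCKET_RECONNECT_DELAY_S=5 before starting the worker (the values are hypothetical).
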
# Extra verbose websocket trace logs (set WEBSOCKET_TRACE=true to enable)
if os.environ.get("WEBSOCKET_TRACE", "false").lower() == "true":
    # This prints low-level frame information to stdout which is invaluable for diagnosing
    # protocol errors but can be noisy in production – therefore gated behind an env-var.
    websocket.enableTrace(True)

# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
# Enforce a clean state after each job is done
# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker
REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"

# ---------------------------------------------------------------------------
# Helper: quick reachability probe of ComfyUI HTTP endpoint (port 8188)
# ---------------------------------------------------------------------------


def _comfy_server_status():
    """Return a dictionary with basic reachability info for the ComfyUI HTTP server."""
    try:
        resp = requests.get(f"http://{COMFY_HOST}/", timeout=5)
        return {
            "reachable": resp.status_code == 200,
            "status_code": resp.status_code,
        }
    except Exception as exc:
        return {"reachable": False, "error": str(exc)}


def _attempt_websocket_reconnect(ws_url, max_attempts, delay_s, initial_error):
    """
    Attempts to reconnect to the WebSocket server after a disconnect.

    Args:
        ws_url (str): The WebSocket URL (including client_id).
        max_attempts (int): Maximum number of reconnection attempts.
        delay_s (int): Delay in seconds between attempts.
        initial_error (Exception): The error that triggered the reconnect attempt.

    Returns:
        websocket.WebSocket: The newly connected WebSocket object.

    Raises:
        websocket.WebSocketConnectionClosedException: If reconnection fails after all attempts.
    """
    print(
        f"worker-comfyui - Websocket connection closed unexpectedly: {initial_error}. Attempting to reconnect..."
    )
    last_reconnect_error = initial_error
    for attempt in range(max_attempts):
        # Log current server status before each reconnect attempt so that we can
        # see whether ComfyUI is still alive (HTTP port 8188 responding) even if
        # the websocket dropped. This is extremely useful to differentiate
        # between a network glitch and an outright ComfyUI crash/OOM-kill.
        srv_status = _comfy_server_status()
        if not srv_status["reachable"]:
            # If ComfyUI itself is down there is no point in retrying the websocket –
            # bail out immediately so the caller gets a clear "ComfyUI crashed" error.
            print(
                f"worker-comfyui - ComfyUI HTTP unreachable – aborting websocket reconnect: {srv_status.get('error', 'status ' + str(srv_status.get('status_code')))}"
            )
            raise websocket.WebSocketConnectionClosedException(
                "ComfyUI HTTP unreachable during websocket reconnect"
            )

        # Otherwise we proceed with reconnect attempts while the server is up
        print(
            f"worker-comfyui - Reconnect attempt {attempt + 1}/{max_attempts}... (ComfyUI HTTP reachable, status {srv_status.get('status_code')})"
        )
        try:
            # Need to create a new socket object for reconnect
            new_ws = websocket.WebSocket()
            new_ws.connect(ws_url, timeout=10)  # Use existing ws_url
            print("worker-comfyui - Websocket reconnected successfully.")
            return new_ws  # Return the new connected socket
        except (
            websocket.WebSocketException,
            ConnectionRefusedError,
            socket.timeout,
            OSError,
        ) as reconn_err:
            last_reconnect_error = reconn_err
            print(
                f"worker-comfyui - Reconnect attempt {attempt + 1} failed: {reconn_err}"
            )
            if attempt < max_attempts - 1:
                print(
                    f"worker-comfyui - Waiting {delay_s} seconds before next attempt..."
                )
                time.sleep(delay_s)
            else:
                print("worker-comfyui - Max reconnection attempts reached.")

    # If the loop completes without returning, raise an exception
    print("worker-comfyui - Failed to reconnect websocket after connection closed.")
    raise websocket.WebSocketConnectionClosedException(
        f"Connection closed and failed to reconnect. Last error: {last_reconnect_error}"
    )


def validate_input(job_input):
    """
    Validates the input for the handler function.

    Args:
        job_input (dict): The input data to validate.

    Returns:
        tuple: A tuple containing the validated data and an error message, if any.
               The structure is (validated_data, error_message).
    """
    # Validate if job_input is provided
    if job_input is None:
        return None, "Please provide input"

    # Check if input is a string and try to parse it as JSON
    if isinstance(job_input, str):
        try:
            job_input = json.loads(job_input)
        except json.JSONDecodeError:
            return None, "Invalid JSON format in input"

    # Validate 'workflow' in input
    workflow = job_input.get("workflow")
    if workflow is None:
        return None, "Missing 'workflow' parameter"

    # Validate 'images' in input, if provided
    images = job_input.get("images")
    if images is not None:
        if not isinstance(images, list) or not all(
            "name" in image and "image" in image for image in images
        ):
            return (
                None,
                "'images' must be a list of objects with 'name' and 'image' keys",
            )

    # Return validated data and no error
    return {"workflow": workflow, "images": images}, None


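# Illustrative only: a minimal job input that passes validate_input() above. The node
# IDs, class types and image name are hypothetical placeholders, not part of this worker.
#
# {
#     "workflow": {
#         "3": {"class_type": "KSampler", "inputs": {"seed": 42, "steps": 20}},
#         "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI"}}
#     },
#     "images": [
#         {"name": "input.png", "image": "<base64 string or data URI>"}
#     ]
# }

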
def check_server(url, retries=500, delay=50):
    """
    Check if a server is reachable via HTTP GET request

    Args:
        - url (str): The URL to check
        - retries (int, optional): The number of times to attempt connecting to the server. Default is 500
        - delay (int, optional): The time in milliseconds to wait between retries. Default is 50

    Returns:
        bool: True if the server is reachable within the given number of retries, otherwise False
    """

    print(f"worker-comfyui - Checking API server at {url}...")
    for i in range(retries):
        try:
            response = requests.get(url, timeout=5)

            # If the response status code is 200, the server is up and running
            if response.status_code == 200:
                print("worker-comfyui - API is reachable")
                return True
        except requests.Timeout:
            pass
        except requests.RequestException:
            pass

        # Wait for the specified delay before retrying
        time.sleep(delay / 1000)

    print(
        f"worker-comfyui - Failed to connect to server at {url} after {retries} attempts."
    )
    return False


def upload_images(images):
    """
    Upload a list of base64 encoded images to the ComfyUI server using the /upload/image endpoint.

    Args:
        images (list): A list of dictionaries, each containing the 'name' of the image and the 'image' as a base64 encoded string.

    Returns:
        dict: A dictionary indicating success or error.
    """
    if not images:
        return {"status": "success", "message": "No images to upload", "details": []}

    responses = []
    upload_errors = []

    print(f"worker-comfyui - Uploading {len(images)} image(s)...")

    for image in images:
        try:
            name = image["name"]
            image_data_uri = image["image"]  # Get the full string (might have prefix)

            # --- Strip Data URI prefix if present ---
            if "," in image_data_uri:
                # Find the comma and take everything after it
                base64_data = image_data_uri.split(",", 1)[1]
            else:
                # Assume it's already pure base64
                base64_data = image_data_uri
            # --- End strip ---

            blob = base64.b64decode(base64_data)  # Decode the cleaned data

            # Prepare the form data
            files = {
                "image": (name, BytesIO(blob), "image/png"),
                "overwrite": (None, "true"),
            }

            # POST request to upload the image
            response = requests.post(
                f"http://{COMFY_HOST}/upload/image", files=files, timeout=30
            )
            response.raise_for_status()

            responses.append(f"Successfully uploaded {name}")
            print(f"worker-comfyui - Successfully uploaded {name}")

        except base64.binascii.Error as e:
            error_msg = f"Error decoding base64 for {image.get('name', 'unknown')}: {e}"
            print(f"worker-comfyui - {error_msg}")
            upload_errors.append(error_msg)
        except requests.Timeout:
            error_msg = f"Timeout uploading {image.get('name', 'unknown')}"
            print(f"worker-comfyui - {error_msg}")
            upload_errors.append(error_msg)
        except requests.RequestException as e:
            error_msg = f"Error uploading {image.get('name', 'unknown')}: {e}"
            print(f"worker-comfyui - {error_msg}")
            upload_errors.append(error_msg)
        except Exception as e:
            error_msg = (
                f"Unexpected error uploading {image.get('name', 'unknown')}: {e}"
            )
            print(f"worker-comfyui - {error_msg}")
            upload_errors.append(error_msg)

    if upload_errors:
        print("worker-comfyui - image(s) upload finished with errors")
        return {
            "status": "error",
            "message": "Some images failed to upload",
            "details": upload_errors,
        }

    print("worker-comfyui - image(s) upload complete")
    return {
        "status": "success",
        "message": "All images uploaded successfully",
        "details": responses,
    }


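# Illustrative only: both of the following "image" values are accepted by upload_images(),
# because anything up to the first comma (a data URI prefix) is stripped before
# base64-decoding. The file name and payload below are placeholders.
#
# {"name": "input.png", "image": "iVBORw0KGgoAAAANSUhEUgAA..."}
# {"name": "input.png", "image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."}

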
def get_available_models():
    """
    Get list of available models from ComfyUI for various node types

    Returns:
        dict: Dictionary containing available models by type and node class
    """
    try:
        response = requests.get(f"http://{COMFY_HOST}/object_info", timeout=10)
        response.raise_for_status()
        object_info = response.json()

        available_models = {}

        # Extract available checkpoints from CheckpointLoaderSimple
        if "CheckpointLoaderSimple" in object_info:
            checkpoint_info = object_info["CheckpointLoaderSimple"]
            if "input" in checkpoint_info and "required" in checkpoint_info["input"]:
                ckpt_options = checkpoint_info["input"]["required"].get("ckpt_name")
                if ckpt_options and len(ckpt_options) > 0:
                    available_models["checkpoints"] = (
                        ckpt_options[0] if isinstance(ckpt_options[0], list) else []
                    )

        # Extract available CLIP models from DualCLIPLoader
        if "DualCLIPLoader" in object_info:
            dual_clip_info = object_info["DualCLIPLoader"]
            if "input" in dual_clip_info and "required" in dual_clip_info["input"]:
                clip1_options = dual_clip_info["input"]["required"].get("clip_name1")
                clip2_options = dual_clip_info["input"]["required"].get("clip_name2")
                if clip1_options and len(clip1_options) > 0:
                    available_models["clip_name1"] = (
                        clip1_options[0] if isinstance(clip1_options[0], list) else []
                    )
                if clip2_options and len(clip2_options) > 0:
                    available_models["clip_name2"] = (
                        clip2_options[0] if isinstance(clip2_options[0], list) else []
                    )

        # Extract available UNET models from UnetLoaderGGUF
        if "UnetLoaderGGUF" in object_info:
            unet_info = object_info["UnetLoaderGGUF"]
            if "input" in unet_info and "required" in unet_info["input"]:
                unet_options = unet_info["input"]["required"].get("unet_name")
                if unet_options and len(unet_options) > 0:
                    available_models["unet_name"] = (
                        unet_options[0] if isinstance(unet_options[0], list) else []
                    )

        return available_models
    except Exception as e:
        print(f"worker-comfyui - Warning: Could not fetch available models: {e}")
        return {}


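# Illustrative only: the slice of a ComfyUI /object_info response that the parser in
# get_available_models() expects. The model filenames are hypothetical, and the option
# list may carry an additional options dict after the list of names.
#
# {
#     "CheckpointLoaderSimple": {
#         "input": {
#             "required": {
#                 "ckpt_name": [["sd_xl_base_1.0.safetensors", "v1-5-pruned-emaonly.safetensors"]]
#             }
#         }
#     }
# }

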
def queue_workflow(workflow, client_id):
    """
    Queue a workflow to be processed by ComfyUI

    Args:
        workflow (dict): A dictionary containing the workflow to be processed
        client_id (str): The client ID for the websocket connection

    Returns:
        dict: The JSON response from ComfyUI after processing the workflow

    Raises:
        ValueError: If the workflow validation fails with detailed error information
    """
    # Include client_id in the prompt payload
    payload = {"prompt": workflow, "client_id": client_id}
    data = json.dumps(payload).encode("utf-8")

    # Use requests for consistency and timeout
    headers = {"Content-Type": "application/json"}
    response = requests.post(
        f"http://{COMFY_HOST}/prompt", data=data, headers=headers, timeout=30
    )

    # Handle validation errors with detailed information
    if response.status_code == 400:
        print(f"worker-comfyui - ComfyUI returned 400. Response body: {response.text}")
        try:
            error_data = response.json()
            print(f"worker-comfyui - Parsed error data: {error_data}")

            # Try to extract meaningful error information
            error_message = "Workflow validation failed"
            error_details = []

            # ComfyUI returns errors in several different formats, so handle them all
            if "error" in error_data:
                error_info = error_data["error"]
                if isinstance(error_info, dict):
                    error_message = error_info.get("message", error_message)
                    if error_info.get("type") == "prompt_outputs_failed_validation":
                        error_message = "Workflow validation failed"
                else:
                    error_message = str(error_info)

            # Check for node validation errors in the response
            node_errors_dict = {}
            if "node_errors" in error_data:
                for node_id, node_error in error_data["node_errors"].items():
                    node_errors_dict[node_id] = node_error
                    if isinstance(node_error, dict):
                        # Extract class_type for later use
                        node_class = node_error.get("class_type", "Unknown")
                        # Format error details for display
                        if "errors" in node_error:
                            errors_list = node_error["errors"]
                            error_summary = []
                            for err in errors_list:
                                if isinstance(err, dict):
                                    err_msg = err.get("message", "")
                                    err_details = err.get("details", "")
                                    if err_details:
                                        error_summary.append(err_details)
                                    elif err_msg:
                                        error_summary.append(err_msg)
                            if error_summary:
                                error_details.append(
                                    f"Node {node_id} ({node_class}): {'; '.join(error_summary)}"
                                )
                            else:
                                error_details.append(
                                    f"Node {node_id} ({node_class}): Validation error"
                                )
                        else:
                            # Fallback: format all keys as error details
                            for error_type, error_msg in node_error.items():
                                if error_type != "class_type" and error_type != "dependent_outputs":
                                    error_details.append(
                                        f"Node {node_id} ({error_type}): {error_msg}"
                                    )
                    else:
                        error_details.append(f"Node {node_id}: {node_error}")

            # Check if the error data itself contains validation info
            if error_data.get("type") == "prompt_outputs_failed_validation":
                error_message = error_data.get("message", "Workflow validation failed")
                # ComfyUI does not include detailed validation errors in the response
                # for this error type, so provide a more helpful generic message instead
                available_models = get_available_models()
                if available_models.get("checkpoints"):
                    error_message += "\n\nThis usually means a required model or parameter is not available."
                    error_message += f"\nAvailable checkpoint models: {', '.join(available_models['checkpoints'])}"
                else:
                    error_message += "\n\nThis usually means a required model or parameter is not available."
                    error_message += "\nNo checkpoint models appear to be available. Please check your model installation."

                raise ValueError(error_message)

            # If we have specific validation errors, format them nicely
            if error_details:
                detailed_message = f"{error_message}:\n" + "\n".join(
                    f"• {detail}" for detail in error_details
                )

                # Extract model information from node errors and provide available alternatives
                available_models = get_available_models()
                model_suggestions = []

                # Parse node_errors to extract missing model information
                for node_id, node_error in node_errors_dict.items():
                    if isinstance(node_error, dict) and "errors" in node_error:
                        node_class = node_error.get("class_type", "Unknown")
                        for error_item in node_error["errors"]:
                            if isinstance(error_item, dict) and error_item.get("type") == "value_not_in_list":
                                extra_info = error_item.get("extra_info", {})
                                input_name = extra_info.get("input_name")
                                received_value = extra_info.get("received_value")
                                input_config = extra_info.get("input_config", [])

                                # Extract available values from input_config (format: [[list_of_available], {}])
                                available_list = []
                                if isinstance(input_config, list) and len(input_config) > 0:
                                    if isinstance(input_config[0], list):
                                        available_list = input_config[0]

                                # Map input names to available models from ComfyUI
                                if input_name == "clip_name1":
                                    if available_models.get("clip_name1"):
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available clip_name1 models: {', '.join(available_models['clip_name1'])}"
                                        )
                                    elif available_list:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available clip_name1 models: {', '.join(available_list)}"
                                        )
                                    else:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Requested '{received_value}' for clip_name1 is not available (no models found)."
                                        )
                                elif input_name == "clip_name2":
                                    if available_models.get("clip_name2"):
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available clip_name2 models: {', '.join(available_models['clip_name2'])}"
                                        )
                                    elif available_list:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available clip_name2 models: {', '.join(available_list)}"
                                        )
                                    else:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Requested '{received_value}' for clip_name2 is not available (no models found)."
                                        )
                                elif input_name == "unet_name":
                                    if available_models.get("unet_name"):
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available unet_name models: {', '.join(available_models['unet_name'])}"
                                        )
                                    elif available_list:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available unet_name models: {', '.join(available_list)}"
                                        )
                                    else:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Requested '{received_value}' for unet_name is not available (no models found)."
                                        )
                                elif input_name == "ckpt_name":
                                    if available_models.get("checkpoints"):
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available checkpoint models: {', '.join(available_models['checkpoints'])}"
                                        )
                                    elif available_list:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Available checkpoint models: {', '.join(available_list)}"
                                        )
                                    else:
                                        model_suggestions.append(
                                            f"Node {node_id} ({node_class}): Requested '{received_value}' for ckpt_name is not available (no models found)."
                                        )
                                elif received_value:
                                    # Generic fallback for any missing model
                                    suggestion_key = f"{node_id}_{input_name}"
                                    if not any(suggestion_key in s for s in model_suggestions):
                                        if available_list:
                                            model_suggestions.append(
                                                f"Node {node_id} ({node_class}): Available {input_name} models: {', '.join(available_list)}"
                                            )
                                        else:
                                            model_suggestions.append(
                                                f"Node {node_id} ({node_class}): Requested '{received_value}' for '{input_name}' is not available."
                                            )

                # Add available model suggestions to the error message
                if model_suggestions:
                    detailed_message += "\n\nAvailable models:\n" + "\n".join(
                        f"  • {suggestion}" for suggestion in model_suggestions
                    )
                elif any("not in list" in detail for detail in error_details):
                    # Fallback: if we couldn't parse specific models, at least mention what we found
                    if available_models:
                        detailed_message += "\n\nNote: Some models may be available. Check your workflow configuration."
                    else:
                        detailed_message += "\n\nNo models appear to be available. Please check your model installation."

                raise ValueError(detailed_message)
            else:
                # Fall back to the raw response if we can't parse specific errors
                raise ValueError(f"{error_message}. Raw response: {response.text}")

        except (json.JSONDecodeError, KeyError):
            # If we can't parse the error response, fall back to the raw text
            raise ValueError(
                f"ComfyUI validation failed (could not parse error response): {response.text}"
            )

    # For other HTTP errors, raise them normally
    response.raise_for_status()
    return response.json()


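# Illustrative only: a 400 response body from /prompt in the shape that the error
# handling in queue_workflow() understands. The node ID, class type and filenames are
# hypothetical placeholders; real ComfyUI responses may contain additional fields.
#
# {
#     "error": {"type": "prompt_invalid", "message": "Prompt has invalid inputs"},
#     "node_errors": {
#         "4": {
#             "class_type": "CheckpointLoaderSimple",
#             "errors": [
#                 {
#                     "type": "value_not_in_list",
#                     "message": "Value not in list",
#                     "details": "ckpt_name: 'missing.safetensors' not in list",
#                     "extra_info": {
#                         "input_name": "ckpt_name",
#                         "received_value": "missing.safetensors",
#                         "input_config": [["sd_xl_base_1.0.safetensors"], {}]
#                     }
#                 }
#             ]
#         }
#     }
# }

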
def get_history(prompt_id):
    """
    Retrieve the history of a given prompt using its ID

    Args:
        prompt_id (str): The ID of the prompt whose history is to be retrieved

    Returns:
        dict: The history of the prompt, containing all the processing steps and results
    """
    # Use requests for consistency and timeout
    response = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}", timeout=30)
    response.raise_for_status()
    return response.json()


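# Illustrative only: the slice of a /history/<prompt_id> response that handler() below
# consumes. The prompt ID, node ID and filename are hypothetical placeholders.
#
# {
#     "<prompt_id>": {
#         "outputs": {
#             "9": {
#                 "images": [
#                     {"filename": "ComfyUI_00001_.png", "subfolder": "", "type": "output"}
#                 ]
#             }
#         }
#     }
# }

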
def get_file_data(filename, subfolder, file_type):
    """
    Fetch file bytes from the ComfyUI /view endpoint.

    Args:
        filename (str): The filename of the file.
        subfolder (str): The subfolder where the file is stored.
        file_type (str): The type of the file (e.g., 'output').

    Returns:
        bytes: The raw file data, or None if an error occurs.
    """
    print(
        f"worker-comfyui - Fetching file data: type={file_type}, subfolder={subfolder}, filename={filename}"
    )
    data = {"filename": filename, "subfolder": subfolder, "type": file_type}
    url_values = urllib.parse.urlencode(data)
    try:
        # Use requests for consistency and timeout
        response = requests.get(f"http://{COMFY_HOST}/view?{url_values}", timeout=60)
        response.raise_for_status()
        print(f"worker-comfyui - Successfully fetched file data for {filename}")
        return response.content
    except requests.Timeout:
        print(f"worker-comfyui - Timeout fetching file data for {filename}")
        return None
    except requests.RequestException as e:
        print(f"worker-comfyui - Error fetching file data for {filename}: {e}")
        return None
    except Exception as e:
        print(
            f"worker-comfyui - Unexpected error fetching file data for {filename}: {e}"
        )
        return None


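# Illustrative only: for an output image in the root output folder, the request built
# by get_file_data() looks like this (the filename is a hypothetical placeholder):
#   GET http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&subfolder=&type=output

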
def handler(job):
    """
    Handles a job using ComfyUI via websockets for status and file retrieval.

    Args:
        job (dict): A dictionary containing job details and input parameters.

    Returns:
        dict: A dictionary containing either an error message or a success status with generated files.
    """
    job_input = job["input"]
    job_id = job["id"]

    # Make sure that the input is valid
    validated_data, error_message = validate_input(job_input)
    if error_message:
        return {"error": error_message}

    # Extract validated data
    workflow = validated_data["workflow"]
    input_images = validated_data.get("images")

    # Make sure that the ComfyUI HTTP API is available before proceeding
    if not check_server(
        f"http://{COMFY_HOST}/",
        COMFY_API_AVAILABLE_MAX_RETRIES,
        COMFY_API_AVAILABLE_INTERVAL_MS,
    ):
        return {
            "error": f"ComfyUI server ({COMFY_HOST}) not reachable after multiple retries."
        }

    # Upload input images if they exist
    if input_images:
        upload_result = upload_images(input_images)
        if upload_result["status"] == "error":
            # Return upload errors
            return {
                "error": "Failed to upload one or more input images",
                "details": upload_result["details"],
            }

    ws = None
    client_id = str(uuid.uuid4())
    prompt_id = None
    output_images = []
    output_audio = []
    errors = []

    try:
        # Establish the WebSocket connection
        ws_url = f"ws://{COMFY_HOST}/ws?clientId={client_id}"
        print(f"worker-comfyui - Connecting to websocket: {ws_url}")
        ws = websocket.WebSocket()
        ws.connect(ws_url, timeout=10)
        print("worker-comfyui - Websocket connected")

        # Queue the workflow
        try:
            queued_workflow = queue_workflow(workflow, client_id)
            prompt_id = queued_workflow.get("prompt_id")
            if not prompt_id:
                raise ValueError(
                    f"Missing 'prompt_id' in queue response: {queued_workflow}"
                )
            print(f"worker-comfyui - Queued workflow with ID: {prompt_id}")
        except requests.RequestException as e:
            print(f"worker-comfyui - Error queuing workflow: {e}")
            raise ValueError(f"Error queuing workflow: {e}")
        except Exception as e:
            print(f"worker-comfyui - Unexpected error queuing workflow: {e}")
            # For ValueError exceptions from queue_workflow, pass through the original message
            if isinstance(e, ValueError):
                raise e
            else:
                raise ValueError(f"Unexpected error queuing workflow: {e}")

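        # Illustrative only: the parsed websocket messages that the loop below reacts
        # to (prompt_id and node values are hypothetical placeholders):
        #   {"type": "status", "data": {"status": {"exec_info": {"queue_remaining": 1}}}}
        #   {"type": "executing", "data": {"node": None, "prompt_id": "<prompt_id>"}}  # execution finished
        #   {"type": "execution_error", "data": {"prompt_id": "<prompt_id>", "node_id": "3",
        #    "node_type": "KSampler", "exception_message": "..."}}
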
        # Wait for execution completion via WebSocket
        print(f"worker-comfyui - Waiting for workflow execution ({prompt_id})...")
        execution_done = False
        while True:
            try:
                out = ws.recv()
                if isinstance(out, str):
                    message = json.loads(out)
                    if message.get("type") == "status":
                        status_data = message.get("data", {}).get("status", {})
                        print(
                            f"worker-comfyui - Status update: {status_data.get('exec_info', {}).get('queue_remaining', 'N/A')} items remaining in queue"
                        )
                    elif message.get("type") == "executing":
                        data = message.get("data", {})
                        if (
                            data.get("node") is None
                            and data.get("prompt_id") == prompt_id
                        ):
                            print(
                                f"worker-comfyui - Execution finished for prompt {prompt_id}"
                            )
                            execution_done = True
                            break
                    elif message.get("type") == "execution_error":
                        data = message.get("data", {})
                        if data.get("prompt_id") == prompt_id:
                            error_details = f"Node Type: {data.get('node_type')}, Node ID: {data.get('node_id')}, Message: {data.get('exception_message')}"
                            print(
                                f"worker-comfyui - Execution error received: {error_details}"
                            )
                            errors.append(f"Workflow execution error: {error_details}")
                            break
                else:
                    continue
            except websocket.WebSocketTimeoutException:
                print("worker-comfyui - Websocket receive timed out. Still waiting...")
                continue
            except websocket.WebSocketConnectionClosedException as closed_err:
                try:
                    # Attempt to reconnect
                    ws = _attempt_websocket_reconnect(
                        ws_url,
                        WEBSOCKET_RECONNECT_ATTEMPTS,
                        WEBSOCKET_RECONNECT_DELAY_S,
                        closed_err,
                    )

                    print(
                        "worker-comfyui - Resuming message listening after successful reconnect."
                    )
                    continue
                except (
                    websocket.WebSocketConnectionClosedException
                ) as reconn_failed_err:
                    # If _attempt_websocket_reconnect fails, it raises this exception.
                    # Let this exception propagate to the outer handler's except block.
                    raise reconn_failed_err

            except json.JSONDecodeError:
                print("worker-comfyui - Received invalid JSON message via websocket.")

        if not execution_done and not errors:
            raise ValueError(
                "Workflow monitoring loop exited without confirmation of completion or error."
            )

        # Fetch history even if there were execution errors, some outputs might exist
        print(f"worker-comfyui - Fetching history for prompt {prompt_id}...")
        history = get_history(prompt_id)

        if prompt_id not in history:
            error_msg = f"Prompt ID {prompt_id} not found in history after execution."
            print(f"worker-comfyui - {error_msg}")
            if not errors:
                return {"error": error_msg}
            else:
                errors.append(error_msg)
                return {
                    "error": "Job processing failed, prompt ID not found in history.",
                    "details": errors,
                }

        prompt_history = history.get(prompt_id, {})
        outputs = prompt_history.get("outputs", {})

        if not outputs:
            warning_msg = f"No outputs found in history for prompt {prompt_id}."
            print(f"worker-comfyui - {warning_msg}")
            if not errors:
                errors.append(warning_msg)

        print(f"worker-comfyui - Processing {len(outputs)} output nodes...")
        for node_id, node_output in outputs.items():
            if "images" in node_output:
                print(
                    f"worker-comfyui - Node {node_id} contains {len(node_output['images'])} image(s)"
                )
                for image_info in node_output["images"]:
                    filename = image_info.get("filename")
                    subfolder = image_info.get("subfolder", "")
                    img_type = image_info.get("type")

                    # skip temp images
                    if img_type == "temp":
                        print(
                            f"worker-comfyui - Skipping image {filename} because type is 'temp'"
                        )
                        continue

                    if not filename:
                        warn_msg = f"Skipping image in node {node_id} due to missing filename: {image_info}"
                        print(f"worker-comfyui - {warn_msg}")
                        errors.append(warn_msg)
                        continue

                    image_bytes = get_file_data(filename, subfolder, img_type)

                    if image_bytes:
                        file_extension = os.path.splitext(filename)[1] or ".png"

                        if os.environ.get("BUCKET_ENDPOINT_URL"):
                            try:
                                with tempfile.NamedTemporaryFile(
                                    suffix=file_extension, delete=False
                                ) as temp_file:
                                    temp_file.write(image_bytes)
                                    temp_file_path = temp_file.name
                                print(
                                    f"worker-comfyui - Wrote image bytes to temporary file: {temp_file_path}"
                                )

                                print(f"worker-comfyui - Uploading {filename} to S3...")
                                s3_url = rp_upload.upload_image(job_id, temp_file_path)
                                os.remove(temp_file_path)  # Clean up temp file
                                print(
                                    f"worker-comfyui - Uploaded {filename} to S3: {s3_url}"
                                )
                                # Append dictionary with filename and URL
                                output_images.append(
                                    {
                                        "filename": filename,
                                        "type": "s3_url",
                                        "data": s3_url,
                                    }
                                )
                            except Exception as e:
                                error_msg = f"Error uploading {filename} to S3: {e}"
                                print(f"worker-comfyui - {error_msg}")
                                errors.append(error_msg)
                                if "temp_file_path" in locals() and os.path.exists(
                                    temp_file_path
                                ):
                                    try:
                                        os.remove(temp_file_path)
                                    except OSError as rm_err:
                                        print(
                                            f"worker-comfyui - Error removing temp file {temp_file_path}: {rm_err}"
                                        )
                        else:
                            # Return as base64 string
                            try:
                                base64_image = base64.b64encode(image_bytes).decode(
                                    "utf-8"
                                )
                                # Append dictionary with filename and base64 data
                                output_images.append(
                                    {
                                        "filename": filename,
                                        "type": "base64",
                                        "data": base64_image,
                                    }
                                )
                                print(f"worker-comfyui - Encoded {filename} as base64")
                            except Exception as e:
                                error_msg = f"Error encoding {filename} to base64: {e}"
                                print(f"worker-comfyui - {error_msg}")
                                errors.append(error_msg)
                    else:
                        error_msg = f"Failed to fetch image data for {filename} from /view endpoint."
                        errors.append(error_msg)

            if "audio" in node_output:
                print(
                    f"worker-comfyui - Node {node_id} contains {len(node_output['audio'])} audio file(s)"
                )
                for audio_info in node_output["audio"]:
                    filename = audio_info.get("filename")
                    subfolder = audio_info.get("subfolder", "")
                    audio_type = audio_info.get("type")
                    if not filename:
                        warn_msg = f"Skipping audio in node {node_id} due to missing filename: {audio_info}"
                        print(f"worker-comfyui - {warn_msg}")
                        errors.append(warn_msg)
                        continue
                    audio_bytes = get_file_data(filename, subfolder, audio_type)
                    if audio_bytes:
                        file_extension = os.path.splitext(filename)[1] or ".wav"
                        if os.environ.get("BUCKET_ENDPOINT_URL"):
                            try:
                                with tempfile.NamedTemporaryFile(
                                    suffix=file_extension, delete=False
                                ) as temp_file:
                                    temp_file.write(audio_bytes)
                                    temp_file_path = temp_file.name
                                print(
                                    f"worker-comfyui - Wrote audio bytes to temporary file: {temp_file_path}"
                                )
                                print(f"worker-comfyui - Uploading {filename} to S3...")
                                s3_url = rp_upload.upload_file(job_id, temp_file_path)
                                os.remove(temp_file_path)
                                print(
                                    f"worker-comfyui - Uploaded {filename} to S3: {s3_url}"
                                )
                                output_audio.append(
                                    {
                                        "filename": filename,
                                        "type": "s3_url",
                                        "data": s3_url,
                                    }
                                )
                            except Exception as e:
                                error_msg = f"Error uploading {filename} to S3: {e}"
                                print(f"worker-comfyui - {error_msg}")
                                errors.append(error_msg)
                                if "temp_file_path" in locals() and os.path.exists(
                                    temp_file_path
                                ):
                                    try:
                                        os.remove(temp_file_path)
                                    except OSError as rm_err:
                                        print(
                                            f"worker-comfyui - Error removing temp file {temp_file_path}: {rm_err}"
                                        )
                        else:
                            try:
                                base64_audio = base64.b64encode(audio_bytes).decode(
                                    "utf-8"
                                )
                                output_audio.append(
                                    {
                                        "filename": filename,
                                        "type": "base64",
                                        "data": base64_audio,
                                    }
                                )
                                print(f"worker-comfyui - Encoded {filename} as base64")
                            except Exception as e:
                                error_msg = f"Error encoding {filename} to base64: {e}"
                                print(f"worker-comfyui - {error_msg}")
                                errors.append(error_msg)
                    else:
                        error_msg = f"Failed to fetch audio data for {filename} from /view endpoint."
                        errors.append(error_msg)

            # Check for other output types
            other_keys = [k for k in node_output.keys() if k not in ["images", "audio"]]
            if other_keys:
                warn_msg = (
                    f"Node {node_id} produced unhandled output keys: {other_keys}."
                )
                print(f"worker-comfyui - WARNING: {warn_msg}")
                print(
                    "worker-comfyui - --> If this output is useful, please consider opening an issue on GitHub to discuss adding support."
                )

    except websocket.WebSocketException as e:
        print(f"worker-comfyui - WebSocket Error: {e}")
        print(traceback.format_exc())
        return {"error": f"WebSocket communication error: {e}"}
    except requests.RequestException as e:
        print(f"worker-comfyui - HTTP Request Error: {e}")
        print(traceback.format_exc())
        return {"error": f"HTTP communication error with ComfyUI: {e}"}
    except ValueError as e:
        print(f"worker-comfyui - Value Error: {e}")
        print(traceback.format_exc())
        return {"error": str(e)}
    except Exception as e:
        print(f"worker-comfyui - Unexpected Handler Error: {e}")
        print(traceback.format_exc())
        return {"error": f"An unexpected error occurred: {e}"}
    finally:
        if ws and ws.connected:
            print("worker-comfyui - Closing websocket connection.")
            ws.close()

    final_result = {}

    if output_images:
        final_result["images"] = output_images

    if output_audio:
        final_result["audio"] = output_audio

    if errors:
        final_result["errors"] = errors
        print(f"worker-comfyui - Job completed with errors/warnings: {errors}")

    if not output_images and not output_audio and errors:
        print("worker-comfyui - Job failed with no output files.")
        return {
            "error": "Job processing failed",
            "details": errors,
        }
    elif not output_images and not output_audio and not errors:
        print(
            "worker-comfyui - Job completed successfully, but the workflow produced no images or audio."
        )
        final_result["status"] = "success_no_output"

    print(f"worker-comfyui - Job completed. Returning {len(output_images)} image(s) and {len(output_audio)} audio file(s).")
    return final_result


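# Illustrative only: the shape of a successful handler() result when BUCKET_ENDPOINT_URL
# is not set (filenames and data are hypothetical placeholders). With a bucket configured,
# "type" becomes "s3_url" and "data" holds the uploaded URL instead of base64 content.
#
# {
#     "images": [
#         {"filename": "ComfyUI_00001_.png", "type": "base64", "data": "iVBORw0KGgoAAAANS..."}
#     ]
# }

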
if __name__ == "__main__":
    print("worker-comfyui - Starting handler...")
    runpod.serverless.start({"handler": handler})