easier support for colab

This commit is contained in:
Benjamin Berman 2025-08-14 09:50:46 -07:00
parent deba203176
commit b3d95afcab
2 changed files with 25 additions and 15 deletions

View File

@ -4,15 +4,10 @@ import stat
import subprocess import subprocess
import threading import threading
from asyncio import Task from asyncio import Task
from typing import NamedTuple from typing import NamedTuple, Optional
import requests import requests
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths # pylint: disable=import-error
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *
class _ColabTuple(NamedTuple): class _ColabTuple(NamedTuple):
tunnel: "CloudflaredTunnel" tunnel: "CloudflaredTunnel"
@ -52,23 +47,22 @@ class CloudflaredTunnel:
for chunk in response.iter_content(chunk_size=8192): for chunk in response.iter_content(chunk_size=8192):
f.write(chunk) f.write(chunk)
# Make the file executable (add execute permission for the owner) current_permissions = os.stat(self._executable_path).st_mode
current_permissions = os.stat(self._executable_path).st_mode os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)
os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)
def _start_tunnel(self) -> str: def _start_tunnel(self) -> str:
"""Starts the tunnel and returns the public URL.""" """Starts the tunnel and returns the public URL."""
command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"] command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"]
# Using DEVNULL for stderr to keep the output clean, stdout is piped
self._process = subprocess.Popen( self._process = subprocess.Popen(
command, command,
bufsize=1,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL, stderr=subprocess.STDOUT,
text=True text=True
) )
for line in iter(self._process.stdout.readline, ""): for line in self._process.stdout:
if ".trycloudflare.com" in line: if ".trycloudflare.com" in line:
# The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |" # The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |"
try: try:
@ -78,7 +72,6 @@ class CloudflaredTunnel:
except IndexError: except IndexError:
continue continue
# If the loop finishes without finding a URL
self.stop() self.stop()
raise RuntimeError("Failed to start cloudflared tunnel or find URL.") raise RuntimeError("Failed to start cloudflared tunnel or find URL.")
@ -124,8 +117,17 @@ def start_server_in_colab() -> str:
:return: :return:
""" """
if len(_colab_instances) == 0: if len(_colab_instances) == 0:
from ..execution_context import ExecutionContext, ServerStub, comfyui_execution_context
from ..component_model.folder_path_types import FolderNames
from ..nodes.package_typing import ExportedNodes
from ..progress_types import ProgressRegistryStub
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub())) comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
# now we're ready to import
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
async def colab_server_loop(): async def colab_server_loop():
init_default_paths(folder_names_and_paths) init_default_paths(folder_names_and_paths)
await _start_comfyui() await _start_comfyui()

View File

@ -4,9 +4,11 @@ import asyncio
import copy import copy
import gc import gc
import json import json
import logging
import threading import threading
import uuid import uuid
from asyncio import get_event_loop from asyncio import get_event_loop
from dataclasses import dataclass
from multiprocessing import RLock from multiprocessing import RLock
from typing import Optional from typing import Optional
@ -28,6 +30,8 @@ from ..execution_context import current_execution_context
_prompt_executor = threading.local() _prompt_executor = threading.local()
logger = logging.getLogger(__name__)
def _execute_prompt( def _execute_prompt(
prompt: dict, prompt: dict,
@ -216,7 +220,11 @@ class Comfy:
prompt: PromptDict | dict, prompt: PromptDict | dict,
prompt_id: Optional[str] = None, prompt_id: Optional[str] = None,
client_id: Optional[str] = None, client_id: Optional[str] = None,
partial_execution_targets: Optional[list[str]] = None) -> dict: partial_execution_targets: Optional[list[str]] = None,
progress_handler: Optional[ExecutorToClientProgress] = None) -> dict:
if isinstance(self._executor, ProcessPoolExecutor) and progress_handler is not None:
logger.debug(f"a progress_handler={progress_handler} was passed, it must be pickleable to support ProcessPoolExecutor")
progress_handler = progress_handler or self._progress_handler
with self._task_count_lock: with self._task_count_lock:
self._task_count += 1 self._task_count += 1
prompt_id = prompt_id or str(uuid.uuid4()) prompt_id = prompt_id or str(uuid.uuid4())
@ -233,7 +241,7 @@ class Comfy:
client_id, client_id,
carrier, carrier,
# todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler, None if isinstance(self._executor, ProcessPoolExecutor) else progress_handler,
self._configuration, self._configuration,
partial_execution_targets, partial_execution_targets,
) )