# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,raise-missing-from,useless-return,wrong-import-position from __future__ import annotations import logging import os from pathlib import Path from typing import Any, Callable, Dict, List, Optional from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped] from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped] try: from comfy.isolation.clip_proxy import CLIPProxy, CLIPRegistry from comfy.isolation.model_patcher_proxy import ( ModelPatcherProxy, ModelPatcherRegistry, ) from comfy.isolation.model_sampling_proxy import ( ModelSamplingProxy, ModelSamplingRegistry, ) from comfy.isolation.vae_proxy import VAEProxy, VAERegistry, FirstStageModelRegistry from comfy.isolation.proxies.folder_paths_proxy import FolderPathsProxy from comfy.isolation.proxies.model_management_proxy import ModelManagementProxy from comfy.isolation.proxies.prompt_server_impl import PromptServerService from comfy.isolation.proxies.utils_proxy import UtilsProxy from comfy.isolation.proxies.progress_proxy import ProgressProxy except ImportError as exc: # Fail loud if Comfy environment is incomplete raise ImportError(f"ComfyUI environment incomplete: {exc}") logger = logging.getLogger(__name__) # Force /dev/shm for shared memory (bwrap makes /tmp private) import tempfile if os.path.exists("/dev/shm"): # Only override if not already set or if default is not /dev/shm current_tmp = tempfile.gettempdir() if not current_tmp.startswith("/dev/shm"): logger.debug( f"Configuring shared memory: Changing TMPDIR from {current_tmp} to /dev/shm" ) os.environ["TMPDIR"] = "/dev/shm" tempfile.tempdir = None # Clear cache to force re-evaluation class ComfyUIAdapter(IsolationAdapter): # ComfyUI-specific IsolationAdapter implementation @property def identifier(self) -> str: return "comfyui" def get_path_config(self, module_path: str) -> Optional[Dict[str, Any]]: if "ComfyUI" in module_path and "custom_nodes" in module_path: parts = module_path.split("ComfyUI") if len(parts) > 1: comfy_root = parts[0] + "ComfyUI" return { "preferred_root": comfy_root, "additional_paths": [ os.path.join(comfy_root, "custom_nodes"), os.path.join(comfy_root, "comfy"), ], } return None def setup_child_environment(self, snapshot: Dict[str, Any]) -> None: comfy_root = snapshot.get("preferred_root") if not comfy_root: return requirements_path = Path(comfy_root) / "requirements.txt" if requirements_path.exists(): import re for line in requirements_path.read_text().splitlines(): line = line.strip() if not line or line.startswith("#"): continue pkg_name = re.split(r"[<>=!~\[]", line)[0].strip() if pkg_name: logging.getLogger(pkg_name).setLevel(logging.ERROR) def register_serializers(self, registry: SerializerRegistryProtocol) -> None: import torch def serialize_device(obj: Any) -> Dict[str, Any]: return {"__type__": "device", "device_str": str(obj)} def deserialize_device(data: Dict[str, Any]) -> Any: return torch.device(data["device_str"]) registry.register("device", serialize_device, deserialize_device) _VALID_DTYPES = { "float16", "float32", "float64", "bfloat16", "int8", "int16", "int32", "int64", "uint8", "bool", } def serialize_dtype(obj: Any) -> Dict[str, Any]: return {"__type__": "dtype", "dtype_str": str(obj)} def deserialize_dtype(data: Dict[str, Any]) -> Any: dtype_name = data["dtype_str"].replace("torch.", "") if dtype_name not in _VALID_DTYPES: raise ValueError(f"Invalid dtype: {data['dtype_str']}") return getattr(torch, dtype_name) registry.register("dtype", serialize_dtype, deserialize_dtype) def serialize_model_patcher(obj: Any) -> Dict[str, Any]: # Child-side: must already have _instance_id (proxy) if os.environ.get("PYISOLATE_CHILD") == "1": if hasattr(obj, "_instance_id"): return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} raise RuntimeError( f"ModelPatcher in child lacks _instance_id: " f"{type(obj).__module__}.{type(obj).__name__}" ) # Host-side: register with registry if hasattr(obj, "_instance_id"): return {"__type__": "ModelPatcherRef", "model_id": obj._instance_id} model_id = ModelPatcherRegistry().register(obj) return {"__type__": "ModelPatcherRef", "model_id": model_id} def deserialize_model_patcher(data: Any) -> Any: """Deserialize ModelPatcher refs; pass through already-materialized objects.""" if isinstance(data, dict): return ModelPatcherProxy( data["model_id"], registry=None, manage_lifecycle=False ) return data def deserialize_model_patcher_ref(data: Dict[str, Any]) -> Any: """Context-aware ModelPatcherRef deserializer for both host and child.""" is_child = os.environ.get("PYISOLATE_CHILD") == "1" if is_child: return ModelPatcherProxy( data["model_id"], registry=None, manage_lifecycle=False ) else: return ModelPatcherRegistry()._get_instance(data["model_id"]) # Register ModelPatcher type for serialization registry.register( "ModelPatcher", serialize_model_patcher, deserialize_model_patcher ) # Register ModelPatcherProxy type (already a proxy, just return ref) registry.register( "ModelPatcherProxy", serialize_model_patcher, deserialize_model_patcher ) # Register ModelPatcherRef for deserialization (context-aware: host or child) registry.register("ModelPatcherRef", None, deserialize_model_patcher_ref) def serialize_clip(obj: Any) -> Dict[str, Any]: if hasattr(obj, "_instance_id"): return {"__type__": "CLIPRef", "clip_id": obj._instance_id} clip_id = CLIPRegistry().register(obj) return {"__type__": "CLIPRef", "clip_id": clip_id} def deserialize_clip(data: Any) -> Any: if isinstance(data, dict): return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) return data def deserialize_clip_ref(data: Dict[str, Any]) -> Any: """Context-aware CLIPRef deserializer for both host and child.""" is_child = os.environ.get("PYISOLATE_CHILD") == "1" if is_child: return CLIPProxy(data["clip_id"], registry=None, manage_lifecycle=False) else: return CLIPRegistry()._get_instance(data["clip_id"]) # Register CLIP type for serialization registry.register("CLIP", serialize_clip, deserialize_clip) # Register CLIPProxy type (already a proxy, just return ref) registry.register("CLIPProxy", serialize_clip, deserialize_clip) # Register CLIPRef for deserialization (context-aware: host or child) registry.register("CLIPRef", None, deserialize_clip_ref) def serialize_vae(obj: Any) -> Dict[str, Any]: if hasattr(obj, "_instance_id"): return {"__type__": "VAERef", "vae_id": obj._instance_id} vae_id = VAERegistry().register(obj) return {"__type__": "VAERef", "vae_id": vae_id} def deserialize_vae(data: Any) -> Any: if isinstance(data, dict): return VAEProxy(data["vae_id"]) return data def deserialize_vae_ref(data: Dict[str, Any]) -> Any: """Context-aware VAERef deserializer for both host and child.""" is_child = os.environ.get("PYISOLATE_CHILD") == "1" if is_child: # Child: create a proxy return VAEProxy(data["vae_id"]) else: # Host: lookup real VAE from registry return VAERegistry()._get_instance(data["vae_id"]) # Register VAE type for serialization registry.register("VAE", serialize_vae, deserialize_vae) # Register VAEProxy type (already a proxy, just return ref) registry.register("VAEProxy", serialize_vae, deserialize_vae) # Register VAERef for deserialization (context-aware: host or child) registry.register("VAERef", None, deserialize_vae_ref) # ModelSampling serialization - handles ModelSampling* types # copyreg removed - no pickle fallback allowed def serialize_model_sampling(obj: Any) -> Dict[str, Any]: # Child-side: must already have _instance_id (proxy) if os.environ.get("PYISOLATE_CHILD") == "1": if hasattr(obj, "_instance_id"): return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} raise RuntimeError( f"ModelSampling in child lacks _instance_id: " f"{type(obj).__module__}.{type(obj).__name__}" ) # Host-side pass-through for proxies: do not re-register a proxy as a # new ModelSamplingRef, or we create proxy-of-proxy indirection. if hasattr(obj, "_instance_id"): return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id} # Host-side: register with ModelSamplingRegistry and return JSON-safe dict ms_id = ModelSamplingRegistry().register(obj) return {"__type__": "ModelSamplingRef", "ms_id": ms_id} def deserialize_model_sampling(data: Any) -> Any: """Deserialize ModelSampling refs; pass through already-materialized objects.""" if isinstance(data, dict): return ModelSamplingProxy(data["ms_id"]) return data def deserialize_model_sampling_ref(data: Dict[str, Any]) -> Any: """Context-aware ModelSamplingRef deserializer for both host and child.""" is_child = os.environ.get("PYISOLATE_CHILD") == "1" if is_child: return ModelSamplingProxy(data["ms_id"]) else: return ModelSamplingRegistry()._get_instance(data["ms_id"]) # Register all ModelSampling* and StableCascadeSampling classes dynamically import comfy.model_sampling for ms_cls in vars(comfy.model_sampling).values(): if not isinstance(ms_cls, type): continue if not issubclass(ms_cls, torch.nn.Module): continue if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"): continue registry.register( ms_cls.__name__, serialize_model_sampling, deserialize_model_sampling, ) registry.register( "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling ) # Register ModelSamplingRef for deserialization (context-aware: host or child) registry.register("ModelSamplingRef", None, deserialize_model_sampling_ref) def serialize_cond(obj: Any) -> Dict[str, Any]: type_key = f"{type(obj).__module__}.{type(obj).__name__}" return { "__type__": type_key, "cond": obj.cond, } def deserialize_cond(data: Dict[str, Any]) -> Any: import importlib type_key = data["__type__"] module_name, class_name = type_key.rsplit(".", 1) module = importlib.import_module(module_name) cls = getattr(module, class_name) return cls(data["cond"]) def _serialize_public_state(obj: Any) -> Dict[str, Any]: state: Dict[str, Any] = {} for key, value in obj.__dict__.items(): if key.startswith("_"): continue if callable(value): continue state[key] = value return state def serialize_latent_format(obj: Any) -> Dict[str, Any]: type_key = f"{type(obj).__module__}.{type(obj).__name__}" return { "__type__": type_key, "state": _serialize_public_state(obj), } def deserialize_latent_format(data: Dict[str, Any]) -> Any: import importlib type_key = data["__type__"] module_name, class_name = type_key.rsplit(".", 1) module = importlib.import_module(module_name) cls = getattr(module, class_name) obj = cls() for key, value in data.get("state", {}).items(): prop = getattr(type(obj), key, None) if isinstance(prop, property) and prop.fset is None: continue setattr(obj, key, value) return obj import comfy.conds for cond_cls in vars(comfy.conds).values(): if not isinstance(cond_cls, type): continue if not issubclass(cond_cls, comfy.conds.CONDRegular): continue type_key = f"{cond_cls.__module__}.{cond_cls.__name__}" registry.register(type_key, serialize_cond, deserialize_cond) registry.register(cond_cls.__name__, serialize_cond, deserialize_cond) import comfy.latent_formats for latent_cls in vars(comfy.latent_formats).values(): if not isinstance(latent_cls, type): continue if not issubclass(latent_cls, comfy.latent_formats.LatentFormat): continue type_key = f"{latent_cls.__module__}.{latent_cls.__name__}" registry.register( type_key, serialize_latent_format, deserialize_latent_format ) registry.register( latent_cls.__name__, serialize_latent_format, deserialize_latent_format ) # V3 API: unwrap NodeOutput.args def deserialize_node_output(data: Any) -> Any: return getattr(data, "args", data) registry.register("NodeOutput", None, deserialize_node_output) # KSAMPLER serializer: stores sampler name instead of function object # sampler_function is a callable which gets filtered out by JSONSocketTransport def serialize_ksampler(obj: Any) -> Dict[str, Any]: func_name = obj.sampler_function.__name__ # Map function name back to sampler name if func_name == "sample_unipc": sampler_name = "uni_pc" elif func_name == "sample_unipc_bh2": sampler_name = "uni_pc_bh2" elif func_name == "dpm_fast_function": sampler_name = "dpm_fast" elif func_name == "dpm_adaptive_function": sampler_name = "dpm_adaptive" elif func_name.startswith("sample_"): sampler_name = func_name[7:] # Remove "sample_" prefix else: sampler_name = func_name return { "__type__": "KSAMPLER", "sampler_name": sampler_name, "extra_options": obj.extra_options, "inpaint_options": obj.inpaint_options, } def deserialize_ksampler(data: Dict[str, Any]) -> Any: import comfy.samplers return comfy.samplers.ksampler( data["sampler_name"], data.get("extra_options", {}), data.get("inpaint_options", {}), ) registry.register("KSAMPLER", serialize_ksampler, deserialize_ksampler) from comfy.isolation.model_patcher_proxy_utils import register_hooks_serializers register_hooks_serializers(registry) # Generic Numpy Serializer def serialize_numpy(obj: Any) -> Any: import torch try: # Attempt zero-copy conversion to Tensor return torch.from_numpy(obj) except Exception: # Fallback for non-numeric arrays (strings, objects, mixes) return obj.tolist() registry.register("ndarray", serialize_numpy, None) def serialize_ply(obj: Any) -> Dict[str, Any]: import base64 import torch if obj.raw_data is not None: return { "__type__": "PLY", "raw_data": base64.b64encode(obj.raw_data).decode("ascii"), } result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)} if obj.colors is not None: result["colors"] = torch.from_numpy(obj.colors) if obj.confidence is not None: result["confidence"] = torch.from_numpy(obj.confidence) if obj.view_id is not None: result["view_id"] = torch.from_numpy(obj.view_id) return result def deserialize_ply(data: Any) -> Any: import base64 from comfy_api.latest._util.ply_types import PLY if "raw_data" in data: return PLY(raw_data=base64.b64decode(data["raw_data"])) return PLY( points=data["points"], colors=data.get("colors"), confidence=data.get("confidence"), view_id=data.get("view_id"), ) registry.register("PLY", serialize_ply, deserialize_ply, data_type=True) def serialize_npz(obj: Any) -> Dict[str, Any]: import base64 return { "__type__": "NPZ", "frames": [base64.b64encode(f).decode("ascii") for f in obj.frames], } def deserialize_npz(data: Any) -> Any: import base64 from comfy_api.latest._util.npz_types import NPZ return NPZ(frames=[base64.b64decode(f) for f in data["frames"]]) registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True) def serialize_file3d(obj: Any) -> Dict[str, Any]: import base64 return { "__type__": "File3D", "format": obj.format, "data": base64.b64encode(obj.get_bytes()).decode("ascii"), } def deserialize_file3d(data: Any) -> Any: import base64 from io import BytesIO from comfy_api.latest._util.geometry_types import File3D return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"]) registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True) def serialize_video(obj: Any) -> Dict[str, Any]: components = obj.get_components() images = components.images.detach() if components.images.requires_grad else components.images result: Dict[str, Any] = { "__type__": "VIDEO", "images": images, "frame_rate_num": components.frame_rate.numerator, "frame_rate_den": components.frame_rate.denominator, } if components.audio is not None: waveform = components.audio["waveform"] if waveform.requires_grad: waveform = waveform.detach() result["audio_waveform"] = waveform result["audio_sample_rate"] = components.audio["sample_rate"] if components.metadata is not None: result["metadata"] = components.metadata return result def deserialize_video(data: Any) -> Any: from fractions import Fraction from comfy_api.latest._input_impl.video_types import VideoFromComponents from comfy_api.latest._util.video_types import VideoComponents audio = None if "audio_waveform" in data: audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]} components = VideoComponents( images=data["images"], frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]), audio=audio, metadata=data.get("metadata"), ) return VideoFromComponents(components) registry.register("VIDEO", serialize_video, deserialize_video, data_type=True) registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True) registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True) def provide_rpc_services(self) -> List[type[ProxiedSingleton]]: return [ PromptServerService, FolderPathsProxy, ModelManagementProxy, UtilsProxy, ProgressProxy, VAERegistry, CLIPRegistry, ModelPatcherRegistry, ModelSamplingRegistry, FirstStageModelRegistry, ] def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: # Resolve the real name whether it's an instance or the Singleton class itself api_name = api.__name__ if isinstance(api, type) else api.__class__.__name__ if api_name == "FolderPathsProxy": import folder_paths # Replace module-level functions with proxy methods # This is aggressive but necessary for transparent proxying # Handle both instance and class cases instance = api() if isinstance(api, type) else api for name in dir(instance): if not name.startswith("_"): setattr(folder_paths, name, getattr(instance, name)) # Fence: isolated children get writable temp inside sandbox if os.environ.get("PYISOLATE_CHILD") == "1": _child_temp = os.path.join("/tmp", "comfyui_temp") os.makedirs(_child_temp, exist_ok=True) folder_paths.temp_directory = _child_temp return if api_name == "ModelManagementProxy": import comfy.model_management instance = api() if isinstance(api, type) else api # Replace module-level functions with proxy methods for name in dir(instance): if not name.startswith("_"): setattr(comfy.model_management, name, getattr(instance, name)) return if api_name == "UtilsProxy": import comfy.utils # Static Injection of RPC mechanism to ensure Child can access it # independent of instance lifecycle. api.set_rpc(rpc) # Don't overwrite host hook (infinite recursion) return if api_name == "PromptServerProxy": # Defer heavy import to child context import server instance = api() if isinstance(api, type) else api proxy = ( instance.instance ) # PromptServerProxy instance has .instance property returning self original_register_route = proxy.register_route def register_route_wrapper( method: str, path: str, handler: Callable[..., Any] ) -> None: callback_id = rpc.register_callback(handler) loop = getattr(rpc, "loop", None) if loop and loop.is_running(): import asyncio asyncio.create_task( original_register_route( method, path, handler=callback_id, is_callback=True ) ) else: original_register_route( method, path, handler=callback_id, is_callback=True ) return None proxy.register_route = register_route_wrapper class RouteTableDefProxy: def __init__(self, proxy_instance: Any): self.proxy = proxy_instance def get( self, path: str, **kwargs: Any ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: self.proxy.register_route("GET", path, handler) return handler return decorator def post( self, path: str, **kwargs: Any ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: self.proxy.register_route("POST", path, handler) return handler return decorator def patch( self, path: str, **kwargs: Any ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: self.proxy.register_route("PATCH", path, handler) return handler return decorator def put( self, path: str, **kwargs: Any ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: self.proxy.register_route("PUT", path, handler) return handler return decorator def delete( self, path: str, **kwargs: Any ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: self.proxy.register_route("DELETE", path, handler) return handler return decorator proxy.routes = RouteTableDefProxy(proxy) if ( hasattr(server, "PromptServer") and getattr(server.PromptServer, "instance", None) != proxy ): server.PromptServer.instance = proxy