# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import # RPC server for ModelPatcher isolation (child process) from __future__ import annotations import asyncio import gc import logging import threading import time from dataclasses import dataclass, field from typing import Any, Optional, List try: from comfy.model_patcher import AutoPatcherEjector except ImportError: class AutoPatcherEjector: def __init__(self, model, skip_and_inject_on_exit_only=False): self.model = model self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only self.prev_skip_injection = False self.was_injected = False def __enter__(self): self.was_injected = False self.prev_skip_injection = self.model.skip_injection if self.skip_and_inject_on_exit_only: self.model.skip_injection = True if self.model.is_injected: self.model.eject_model() self.was_injected = True def __exit__(self, *args): if self.skip_and_inject_on_exit_only: self.model.skip_injection = self.prev_skip_injection self.model.inject_model() if self.was_injected and not self.model.skip_injection: self.model.inject_model() self.model.skip_injection = self.prev_skip_injection from comfy.isolation.proxies.base import ( BaseRegistry, detach_if_grad, ) logger = logging.getLogger(__name__) @dataclass class _OperationState: lease: threading.Lock = field(default_factory=threading.Lock) active_count: int = 0 active_by_method: dict[str, int] = field(default_factory=dict) total_operations: int = 0 last_method: Optional[str] = None last_started_ts: Optional[float] = None last_ended_ts: Optional[float] = None last_elapsed_ms: Optional[float] = None last_error: Optional[str] = None last_thread_id: Optional[int] = None last_loop_id: Optional[int] = None class ModelPatcherRegistry(BaseRegistry[Any]): _type_prefix = "model" def __init__(self) -> None: super().__init__() self._pending_cleanup_ids: set[str] = set() self._operation_states: dict[str, _OperationState] = {} self._operation_state_cv = threading.Condition(self._lock) def _get_or_create_operation_state(self, instance_id: str) -> _OperationState: state = self._operation_states.get(instance_id) if state is None: state = _OperationState() self._operation_states[instance_id] = state return state def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]: start_epoch = time.time() start_perf = time.perf_counter() with self._operation_state_cv: state = self._get_or_create_operation_state(instance_id) state.active_count += 1 state.active_by_method[method_name] = ( state.active_by_method.get(method_name, 0) + 1 ) state.total_operations += 1 state.last_method = method_name state.last_started_ts = start_epoch state.last_thread_id = threading.get_ident() try: state.last_loop_id = id(asyncio.get_running_loop()) except RuntimeError: state.last_loop_id = None logger.debug( "ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s", instance_id, method_name, start_epoch, threading.get_ident(), state.last_loop_id, ) return start_epoch, start_perf def _end_operation( self, instance_id: str, method_name: str, start_perf: float, error: Optional[BaseException] = None, ) -> None: end_epoch = time.time() elapsed_ms = (time.perf_counter() - start_perf) * 1000.0 with self._operation_state_cv: state = self._get_or_create_operation_state(instance_id) state.active_count = max(0, state.active_count - 1) if method_name in state.active_by_method: remaining = state.active_by_method[method_name] - 1 if remaining <= 0: state.active_by_method.pop(method_name, None) else: state.active_by_method[method_name] = remaining state.last_ended_ts = end_epoch state.last_elapsed_ms = elapsed_ms state.last_error = None if error is None else repr(error) if state.active_count == 0: self._operation_state_cv.notify_all() logger.debug( "ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s", instance_id, method_name, end_epoch, elapsed_ms, None if error is None else type(error).__name__, ) def _run_operation_with_lease(self, instance_id: str, method_name: str, fn): with self._operation_state_cv: state = self._get_or_create_operation_state(instance_id) lease = state.lease with lease: _, start_perf = self._begin_operation(instance_id, method_name) try: result = fn() except Exception as exc: self._end_operation(instance_id, method_name, start_perf, error=exc) raise self._end_operation(instance_id, method_name, start_perf) return result def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]: with self._operation_state_cv: state = self._operation_states.get(instance_id) if state is None: return { "instance_id": instance_id, "active_count": 0, "active_methods": [], "total_operations": 0, "last_method": None, "last_started_ts": None, "last_ended_ts": None, "last_elapsed_ms": None, "last_error": None, "last_thread_id": None, "last_loop_id": None, } return { "instance_id": instance_id, "active_count": state.active_count, "active_methods": sorted(state.active_by_method.keys()), "total_operations": state.total_operations, "last_method": state.last_method, "last_started_ts": state.last_started_ts, "last_ended_ts": state.last_ended_ts, "last_elapsed_ms": state.last_elapsed_ms, "last_error": state.last_error, "last_thread_id": state.last_thread_id, "last_loop_id": state.last_loop_id, } def unregister_sync(self, instance_id: str) -> None: with self._operation_state_cv: instance = self._registry.pop(instance_id, None) if instance is not None: self._id_map.pop(id(instance), None) self._pending_cleanup_ids.discard(instance_id) self._operation_states.pop(instance_id, None) self._operation_state_cv.notify_all() async def get_operation_state(self, instance_id: str) -> dict[str, Any]: return self._snapshot_operation_state(instance_id) async def get_all_operation_states(self) -> dict[str, dict[str, Any]]: with self._operation_state_cv: ids = sorted(self._operation_states.keys()) return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids} async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool: timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) deadline = None if timeout_s is None else (time.monotonic() + timeout_s) with self._operation_state_cv: while True: active = self._operation_states.get(instance_id) if active is None or active.active_count == 0: return True if deadline is None: self._operation_state_cv.wait() continue remaining = deadline - time.monotonic() if remaining <= 0: return False self._operation_state_cv.wait(timeout=remaining) async def wait_all_idle(self, timeout_ms: int = 0) -> bool: timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0) deadline = None if timeout_s is None else (time.monotonic() + timeout_s) with self._operation_state_cv: while True: has_active = any( state.active_count > 0 for state in self._operation_states.values() ) if not has_active: return True if deadline is None: self._operation_state_cv.wait() continue remaining = deadline - time.monotonic() if remaining <= 0: return False self._operation_state_cv.wait(timeout=remaining) async def clone(self, instance_id: str) -> str: instance = self._get_instance(instance_id) new_model = instance.clone() return self.register(new_model) async def is_clone(self, instance_id: str, other: Any) -> bool: instance = self._get_instance(instance_id) if hasattr(other, "model"): return instance.is_clone(other) return False async def get_model_object(self, instance_id: str, name: str) -> Any: print(f"GP_DEBUG: get_model_object START for name={name}", flush=True) instance = self._get_instance(instance_id) if name == "model": print(f"GP_DEBUG: get_model_object END for name={name} (ModelObject)", flush=True) return f"" result = instance.get_model_object(name) if name == "model_sampling": from comfy.isolation.model_sampling_proxy import ( ModelSamplingRegistry, ModelSamplingProxy, ) registry = ModelSamplingRegistry() # Preserve identity when upstream already returned a proxy. Re-registering # a proxy object creates proxy-of-proxy call chains. if isinstance(result, ModelSamplingProxy): sampling_id = result._instance_id else: sampling_id = registry.register(result) print(f"GP_DEBUG: get_model_object END for name={name} (model_sampling)", flush=True) return ModelSamplingProxy(sampling_id, registry) print(f"GP_DEBUG: get_model_object END for name={name} (fallthrough)", flush=True) return detach_if_grad(result) async def get_model_options(self, instance_id: str) -> dict: instance = self._get_instance(instance_id) import copy opts = copy.deepcopy(instance.model_options) return self._sanitize_rpc_result(opts) async def set_model_options(self, instance_id: str, options: dict) -> None: self._get_instance(instance_id).model_options = options async def get_patcher_attr(self, instance_id: str, name: str) -> Any: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id), name, None) ) async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any: instance = self._get_instance(instance_id) sd_keys = instance.model.state_dict().keys() return dict.fromkeys(sd_keys, None) def _sanitize_rpc_result(self, obj, seen=None): if seen is None: seen = set() if obj is None: return None if isinstance(obj, (bool, int, float, str)): if isinstance(obj, str) and len(obj) > 500000: return f"" return obj obj_id = id(obj) if obj_id in seen: return None seen.add(obj_id) if isinstance(obj, (list, tuple)): return [self._sanitize_rpc_result(x, seen) for x in obj] if isinstance(obj, set): return [self._sanitize_rpc_result(x, seen) for x in obj] if isinstance(obj, dict): new_dict = {} for k, v in obj.items(): if isinstance(k, tuple): import json try: key_str = "__pyisolate_key__" + json.dumps(list(k)) new_dict[key_str] = self._sanitize_rpc_result(v, seen) except Exception: new_dict[str(k)] = self._sanitize_rpc_result(v, seen) else: new_dict[str(k)] = self._sanitize_rpc_result(v, seen) return new_dict if ( hasattr(obj, "__dict__") and not hasattr(obj, "__get__") and not hasattr(obj, "__call__") ): return self._sanitize_rpc_result(obj.__dict__, seen) if hasattr(obj, "items") and hasattr(obj, "get"): return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()} return None async def get_load_device(self, instance_id: str) -> Any: return self._get_instance(instance_id).load_device async def get_offload_device(self, instance_id: str) -> Any: return self._get_instance(instance_id).offload_device async def current_loaded_device(self, instance_id: str) -> Any: return self._get_instance(instance_id).current_loaded_device() async def get_size(self, instance_id: str) -> int: return self._get_instance(instance_id).size async def model_size(self, instance_id: str) -> Any: return self._get_instance(instance_id).model_size() async def loaded_size(self, instance_id: str) -> Any: return self._get_instance(instance_id).loaded_size() async def get_ram_usage(self, instance_id: str) -> int: return self._get_instance(instance_id).get_ram_usage() async def lowvram_patch_counter(self, instance_id: str) -> int: return self._get_instance(instance_id).lowvram_patch_counter() async def memory_required(self, instance_id: str, input_shape: Any) -> Any: return self._run_operation_with_lease( instance_id, "memory_required", lambda: self._get_instance(instance_id).memory_required(input_shape), ) async def is_dynamic(self, instance_id: str) -> bool: instance = self._get_instance(instance_id) if hasattr(instance, "is_dynamic"): return bool(instance.is_dynamic()) return False async def get_free_memory(self, instance_id: str, device: Any) -> Any: instance = self._get_instance(instance_id) if hasattr(instance, "get_free_memory"): return instance.get_free_memory(device) import comfy.model_management return comfy.model_management.get_free_memory(device) async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any: instance = self._get_instance(instance_id) if hasattr(instance, "partially_unload_ram"): return instance.partially_unload_ram(ram_to_unload) return None async def model_dtype(self, instance_id: str) -> Any: return self._run_operation_with_lease( instance_id, "model_dtype", lambda: self._get_instance(instance_id).model_dtype(), ) async def model_patches_to(self, instance_id: str, device: Any) -> Any: return self._get_instance(instance_id).model_patches_to(device) async def partially_load( self, instance_id: str, device: Any, extra_memory: Any, force_patch_weights: bool = False, ) -> Any: return self._run_operation_with_lease( instance_id, "partially_load", lambda: self._get_instance(instance_id).partially_load( device, extra_memory, force_patch_weights=force_patch_weights ), ) async def partially_unload( self, instance_id: str, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False, ) -> int: return self._run_operation_with_lease( instance_id, "partially_unload", lambda: self._get_instance(instance_id).partially_unload( device_to, memory_to_free, force_patch_weights ), ) async def load( self, instance_id: str, device_to: Any = None, lowvram_model_memory: int = 0, force_patch_weights: bool = False, full_load: bool = False, ) -> None: self._run_operation_with_lease( instance_id, "load", lambda: self._get_instance(instance_id).load( device_to, lowvram_model_memory, force_patch_weights, full_load ), ) async def patch_model( self, instance_id: str, device_to: Any = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False, ) -> None: def _invoke() -> None: try: self._get_instance(instance_id).patch_model( device_to, lowvram_model_memory, load_weights, force_patch_weights ) except AttributeError as e: logger.error( f"Isolation Error: Failed to patch model attribute: {e}. Skipping." ) return self._run_operation_with_lease(instance_id, "patch_model", _invoke) async def unpatch_model( self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True ) -> None: self._run_operation_with_lease( instance_id, "unpatch_model", lambda: self._get_instance(instance_id).unpatch_model( device_to, unpatch_weights ), ) async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: self._get_instance(instance_id).detach(unpatch_all) async def prepare_state(self, instance_id: str, timestep: Any) -> Any: instance = self._get_instance(instance_id) cp = getattr(instance.model, "current_patcher", instance) if cp is None: cp = instance return cp.prepare_state(timestep) async def pre_run(self, instance_id: str) -> None: self._get_instance(instance_id).pre_run() async def cleanup(self, instance_id: str) -> None: def _invoke() -> None: try: instance = self._get_instance(instance_id) except Exception: logger.debug( "ModelPatcher cleanup requested for missing instance %s", instance_id, exc_info=True, ) return try: instance.cleanup() finally: with self._lock: self._pending_cleanup_ids.add(instance_id) gc.collect() self._run_operation_with_lease(instance_id, "cleanup", _invoke) def sweep_pending_cleanup(self) -> int: removed = 0 with self._operation_state_cv: pending_ids = list(self._pending_cleanup_ids) self._pending_cleanup_ids.clear() for instance_id in pending_ids: instance = self._registry.pop(instance_id, None) if instance is None: continue self._id_map.pop(id(instance), None) self._operation_states.pop(instance_id, None) removed += 1 self._operation_state_cv.notify_all() gc.collect() return removed def purge_all(self) -> int: with self._operation_state_cv: removed = len(self._registry) self._registry.clear() self._id_map.clear() self._pending_cleanup_ids.clear() self._operation_states.clear() self._operation_state_cv.notify_all() gc.collect() return removed async def apply_hooks(self, instance_id: str, hooks: Any) -> Any: instance = self._get_instance(instance_id) cp = getattr(instance.model, "current_patcher", instance) if cp is None: cp = instance return cp.apply_hooks(hooks=hooks) async def clean_hooks(self, instance_id: str) -> None: self._get_instance(instance_id).clean_hooks() async def restore_hook_patches(self, instance_id: str) -> None: self._get_instance(instance_id).restore_hook_patches() async def unpatch_hooks( self, instance_id: str, whitelist_keys_set: Optional[set] = None ) -> None: self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set) async def register_all_hook_patches( self, instance_id: str, hooks: Any, target_dict: Any, model_options: Any, registered: Any, ) -> None: from types import SimpleNamespace import comfy.hooks instance = self._get_instance(instance_id) if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"): hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks new_hooks = comfy.hooks.HookGroup() if hasattr(hook_data, "hooks"): new_hooks.hooks = ( hook_data["hooks"] if isinstance(hook_data, dict) else hook_data.hooks ) hooks = new_hooks instance.register_all_hook_patches( hooks, target_dict, model_options, registered ) async def get_hook_mode(self, instance_id: str) -> Any: return getattr(self._get_instance(instance_id), "hook_mode", None) async def set_hook_mode(self, instance_id: str, value: Any) -> None: setattr(self._get_instance(instance_id), "hook_mode", value) async def inject_model(self, instance_id: str) -> None: instance = self._get_instance(instance_id) try: instance.inject_model() except AttributeError as e: if "inject" in str(e): logger.error( "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping." ) return raise e async def eject_model(self, instance_id: str) -> None: self._get_instance(instance_id).eject_model() async def get_is_injected(self, instance_id: str) -> bool: return self._get_instance(instance_id).is_injected async def set_skip_injection(self, instance_id: str, value: bool) -> None: self._get_instance(instance_id).skip_injection = value async def get_skip_injection(self, instance_id: str) -> bool: return self._get_instance(instance_id).skip_injection async def set_model_sampler_cfg_function( self, instance_id: str, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False, ) -> None: if not callable(sampler_cfg_function): logger.error( f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping." ) return self._get_instance(instance_id).set_model_sampler_cfg_function( sampler_cfg_function, disable_cfg1_optimization ) async def set_model_sampler_post_cfg_function( self, instance_id: str, post_cfg_function: Any, disable_cfg1_optimization: bool = False, ) -> None: self._get_instance(instance_id).set_model_sampler_post_cfg_function( post_cfg_function, disable_cfg1_optimization ) async def set_model_sampler_pre_cfg_function( self, instance_id: str, pre_cfg_function: Any, disable_cfg1_optimization: bool = False, ) -> None: self._get_instance(instance_id).set_model_sampler_pre_cfg_function( pre_cfg_function, disable_cfg1_optimization ) async def set_model_sampler_calc_cond_batch_function( self, instance_id: str, fn: Any ) -> None: self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn) async def set_model_unet_function_wrapper( self, instance_id: str, unet_wrapper_function: Any ) -> None: self._get_instance(instance_id).set_model_unet_function_wrapper( unet_wrapper_function ) async def set_model_denoise_mask_function( self, instance_id: str, denoise_mask_function: Any ) -> None: self._get_instance(instance_id).set_model_denoise_mask_function( denoise_mask_function ) async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None: self._get_instance(instance_id).set_model_patch(patch, name) async def set_model_patch_replace( self, instance_id: str, patch: Any, name: str, block_name: str, number: int, transformer_index: Optional[int] = None, ) -> None: self._get_instance(instance_id).set_model_patch_replace( patch, name, block_name, number, transformer_index ) async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None: self._get_instance(instance_id).set_model_input_block_patch(patch) async def set_model_input_block_patch_after_skip( self, instance_id: str, patch: Any ) -> None: self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch) async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None: self._get_instance(instance_id).set_model_output_block_patch(patch) async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None: self._get_instance(instance_id).set_model_emb_patch(patch) async def set_model_forward_timestep_embed_patch( self, instance_id: str, patch: Any ) -> None: self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch) async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None: self._get_instance(instance_id).set_model_double_block_patch(patch) async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None: self._get_instance(instance_id).set_model_post_input_patch(patch) async def set_model_rope_options(self, instance_id: str, options: dict) -> None: self._get_instance(instance_id).set_model_rope_options(**options) async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None: self._get_instance(instance_id).set_model_compute_dtype(dtype) async def clone_has_same_weights_by_id( self, instance_id: str, other_id: str ) -> bool: instance = self._get_instance(instance_id) other = self._get_instance(other_id) if not other: return False return instance.clone_has_same_weights(other) async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any: return self._get_instance(instance_id)._load_list(*args, **kwargs) async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool: instance = self._get_instance(instance_id) other = self._get_instance(other_id) if hasattr(instance, "is_clone"): return instance.is_clone(other) return False async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None: self._get_instance(instance_id).add_object_patch(name, obj) async def add_weight_wrapper( self, instance_id: str, name: str, function: Any ) -> None: self._get_instance(instance_id).add_weight_wrapper(name, function) async def add_wrapper_with_key( self, instance_id: str, wrapper_type: Any, key: str, fn: Any ) -> None: self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn) async def remove_wrappers_with_key( self, instance_id: str, wrapper_type: str, key: str ) -> None: self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key) async def get_wrappers( self, instance_id: str, wrapper_type: str = None, key: str = None ) -> Any: if wrapper_type is None and key is None: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id), "wrappers", {}) ) return self._sanitize_rpc_result( self._get_instance(instance_id).get_wrappers(wrapper_type, key) ) async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])( wrapper_type ) ) async def add_callback_with_key( self, instance_id: str, call_type: str, key: str, callback: Any ) -> None: self._get_instance(instance_id).add_callback_with_key(call_type, key, callback) async def remove_callbacks_with_key( self, instance_id: str, call_type: str, key: str ) -> None: self._get_instance(instance_id).remove_callbacks_with_key(call_type, key) async def get_callbacks( self, instance_id: str, call_type: str = None, key: str = None ) -> Any: if call_type is None and key is None: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id), "callbacks", {}) ) return self._sanitize_rpc_result( self._get_instance(instance_id).get_callbacks(call_type, key) ) async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])( call_type ) ) async def set_attachments( self, instance_id: str, key: str, attachment: Any ) -> None: self._get_instance(instance_id).set_attachments(key, attachment) async def get_attachment(self, instance_id: str, key: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).get_attachment(key) ) async def remove_attachments(self, instance_id: str, key: str) -> None: self._get_instance(instance_id).remove_attachments(key) async def set_injections(self, instance_id: str, key: str, injections: Any) -> None: self._get_instance(instance_id).set_injections(key, injections) async def get_injections(self, instance_id: str, key: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).get_injections(key) ) async def remove_injections(self, instance_id: str, key: str) -> None: self._get_instance(instance_id).remove_injections(key) async def set_additional_models( self, instance_id: str, key: str, models: Any ) -> None: self._get_instance(instance_id).set_additional_models(key, models) async def remove_additional_models(self, instance_id: str, key: str) -> None: self._get_instance(instance_id).remove_additional_models(key) async def get_nested_additional_models(self, instance_id: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).get_nested_additional_models() ) async def get_additional_models(self, instance_id: str) -> List[str]: models = self._get_instance(instance_id).get_additional_models() return [self.register(m) for m in models] async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).get_additional_models_with_key(key) ) async def model_patches_models(self, instance_id: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).model_patches_models() ) async def get_patches(self, instance_id: str) -> Any: return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy()) async def get_object_patches(self, instance_id: str) -> Any: return self._sanitize_rpc_result( self._get_instance(instance_id).object_patches.copy() ) async def add_patches( self, instance_id: str, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0, ) -> Any: return self._get_instance(instance_id).add_patches( patches, strength_patch, strength_model ) async def get_key_patches( self, instance_id: str, filter_prefix: Optional[str] = None ) -> Any: res = self._get_instance(instance_id).get_key_patches() if filter_prefix: res = {k: v for k, v in res.items() if k.startswith(filter_prefix)} safe_res = {} for k, v in res.items(): safe_res[k] = [ f"" if hasattr(t, "shape") else str(t) for t in v ] return safe_res async def add_hook_patches( self, instance_id: str, hook: Any, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0, ) -> None: if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict): try: hook.hook_ref = tuple(sorted(hook.hook_ref.items())) except Exception: hook.hook_ref = None self._get_instance(instance_id).add_hook_patches( hook, patches, strength_patch, strength_model ) async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any: if hooks is not None and hasattr(hooks, "hooks"): for hook in getattr(hooks, "hooks", []): hook_ref = getattr(hook, "hook_ref", None) if isinstance(hook_ref, dict): try: hook.hook_ref = tuple(sorted(hook_ref.items())) except Exception: hook.hook_ref = None res = self._get_instance(instance_id).get_combined_hook_patches(hooks) return self._sanitize_rpc_result(res) async def clear_cached_hook_weights(self, instance_id: str) -> None: self._get_instance(instance_id).clear_cached_hook_weights() async def prepare_hook_patches_current_keyframe( self, instance_id: str, t: Any, hook_group: Any, model_options: Any ) -> None: self._get_instance(instance_id).prepare_hook_patches_current_keyframe( t, hook_group, model_options ) async def get_parent(self, instance_id: str) -> Any: return getattr(self._get_instance(instance_id), "parent", None) async def patch_weight_to_device( self, instance_id: str, key: str, device_to: Any = None, inplace_update: bool = False, ) -> None: self._get_instance(instance_id).patch_weight_to_device( key, device_to, inplace_update ) async def pin_weight_to_device(self, instance_id: str, key: str) -> None: instance = self._get_instance(instance_id) if hasattr(instance, "pinned") and isinstance(instance.pinned, list): instance.pinned = set(instance.pinned) instance.pin_weight_to_device(key) async def unpin_weight(self, instance_id: str, key: str) -> None: instance = self._get_instance(instance_id) if hasattr(instance, "pinned") and isinstance(instance.pinned, list): instance.pinned = set(instance.pinned) instance.unpin_weight(key) async def unpin_all_weights(self, instance_id: str) -> None: instance = self._get_instance(instance_id) if hasattr(instance, "pinned") and isinstance(instance.pinned, list): instance.pinned = set(instance.pinned) instance.unpin_all_weights() async def calculate_weight( self, instance_id: str, patches: Any, weight: Any, key: str, intermediate_dtype: Any = float, ) -> Any: return detach_if_grad( self._get_instance(instance_id).calculate_weight( patches, weight, key, intermediate_dtype ) ) async def get_inner_model_attr(self, instance_id: str, name: str) -> Any: try: return self._sanitize_rpc_result( getattr(self._get_instance(instance_id).model, name) ) except AttributeError: return None async def inner_model_memory_required( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: return self._run_operation_with_lease( instance_id, "inner_model_memory_required", lambda: self._get_instance(instance_id).model.memory_required( *args, **kwargs ), ) async def inner_model_extra_conds_shapes( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: return self._run_operation_with_lease( instance_id, "inner_model_extra_conds_shapes", lambda: self._get_instance(instance_id).model.extra_conds_shapes( *args, **kwargs ), ) async def inner_model_extra_conds( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: def _invoke() -> Any: result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs) try: import torch import comfy.conds except Exception: return result def _to_cpu(obj: Any) -> Any: if torch.is_tensor(obj): return obj.detach().cpu() if obj.device.type != "cpu" else obj if isinstance(obj, dict): return {k: _to_cpu(v) for k, v in obj.items()} if isinstance(obj, list): return [_to_cpu(v) for v in obj] if isinstance(obj, tuple): return tuple(_to_cpu(v) for v in obj) if isinstance(obj, comfy.conds.CONDRegular): return type(obj)(_to_cpu(obj.cond)) return obj return _to_cpu(result) return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke) async def inner_model_state_dict( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs) return { k: {"numel": v.numel(), "element_size": v.element_size()} for k, v in sd.items() } async def inner_model_apply_model( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: def _invoke() -> Any: import torch instance = self._get_instance(instance_id) target = getattr(instance, "load_device", None) if target is None and args and hasattr(args[0], "device"): target = args[0].device elif target is None: for v in kwargs.values(): if hasattr(v, "device"): target = v.device break def _move(obj): if target is None: return obj if isinstance(obj, (tuple, list)): return type(obj)(_move(o) for o in obj) if hasattr(obj, "to"): return obj.to(target) return obj moved_args = tuple(_move(a) for a in args) moved_kwargs = {k: _move(v) for k, v in kwargs.items()} result = instance.model.apply_model(*moved_args, **moved_kwargs) moved_result = detach_if_grad(_move(result)) # DynamicVRAM + isolation: returning CUDA tensors across RPC can stall # at the transport boundary. Marshal dynamic-path results as CPU and let # the proxy restore device placement in the child process. is_dynamic_fn = getattr(instance, "is_dynamic", None) if callable(is_dynamic_fn) and is_dynamic_fn(): def _to_cpu(obj: Any) -> Any: if torch.is_tensor(obj): return obj.detach().cpu() if obj.device.type != "cpu" else obj if isinstance(obj, dict): return {k: _to_cpu(v) for k, v in obj.items()} if isinstance(obj, list): return [_to_cpu(v) for v in obj] if isinstance(obj, tuple): return tuple(_to_cpu(v) for v in obj) return obj return _to_cpu(moved_result) return moved_result return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke) async def process_latent_in( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: return self._run_operation_with_lease( instance_id, "process_latent_in", lambda: detach_if_grad( self._get_instance(instance_id).model.process_latent_in( *args, **kwargs ) ), ) async def process_latent_out( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: import torch def _invoke() -> Any: instance = self._get_instance(instance_id) result = instance.model.process_latent_out(*args, **kwargs) moved_result = None try: target = None if args and hasattr(args[0], "device"): target = args[0].device elif kwargs: for v in kwargs.values(): if hasattr(v, "device"): target = v.device break if target is not None and hasattr(result, "to"): moved_result = detach_if_grad(result.to(target)) except Exception: logger.debug( "process_latent_out: failed to move result to target device", exc_info=True, ) if moved_result is None: moved_result = detach_if_grad(result) is_dynamic_fn = getattr(instance, "is_dynamic", None) if callable(is_dynamic_fn) and is_dynamic_fn(): def _to_cpu(obj: Any) -> Any: if torch.is_tensor(obj): return obj.detach().cpu() if obj.device.type != "cpu" else obj if isinstance(obj, dict): return {k: _to_cpu(v) for k, v in obj.items()} if isinstance(obj, list): return [_to_cpu(v) for v in obj] if isinstance(obj, tuple): return tuple(_to_cpu(v) for v in obj) return obj return _to_cpu(moved_result) return moved_result return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke) async def scale_latent_inpaint( self, instance_id: str, args: tuple, kwargs: dict ) -> Any: import torch def _invoke() -> Any: instance = self._get_instance(instance_id) result = instance.model.scale_latent_inpaint(*args, **kwargs) moved_result = None try: target = None if args and hasattr(args[0], "device"): target = args[0].device elif kwargs: for v in kwargs.values(): if hasattr(v, "device"): target = v.device break if target is not None and hasattr(result, "to"): moved_result = detach_if_grad(result.to(target)) except Exception: logger.debug( "scale_latent_inpaint: failed to move result to target device", exc_info=True, ) if moved_result is None: moved_result = detach_if_grad(result) is_dynamic_fn = getattr(instance, "is_dynamic", None) if callable(is_dynamic_fn) and is_dynamic_fn(): def _to_cpu(obj: Any) -> Any: if torch.is_tensor(obj): return obj.detach().cpu() if obj.device.type != "cpu" else obj if isinstance(obj, dict): return {k: _to_cpu(v) for k, v in obj.items()} if isinstance(obj, list): return [_to_cpu(v) for v in obj] if isinstance(obj, tuple): return tuple(_to_cpu(v) for v in obj) return obj return _to_cpu(moved_result) return moved_result return self._run_operation_with_lease( instance_id, "scale_latent_inpaint", _invoke ) async def load_lora( self, instance_id: str, lora_path: str, strength_model: float, clip_id: Optional[str] = None, strength_clip: float = 1.0, ) -> dict: import comfy.utils import comfy.sd import folder_paths from comfy.isolation.clip_proxy import CLIPRegistry model = self._get_instance(instance_id) clip = None if clip_id: clip = CLIPRegistry()._get_instance(clip_id) lora_full_path = folder_paths.get_full_path("loras", lora_path) if lora_full_path is None: raise ValueError(f"LoRA file not found: {lora_path}") lora = comfy.utils.load_torch_file(lora_full_path) new_model, new_clip = comfy.sd.load_lora_for_models( model, clip, lora, strength_model, strength_clip ) new_model_id = self.register(new_model) if new_model else None new_clip_id = ( CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None ) return {"model_id": new_model_id, "clip_id": new_clip_id}