diff --git a/comfy/isolation/clip_proxy.py b/comfy/isolation/clip_proxy.py new file mode 100644 index 000000000..371665314 --- /dev/null +++ b/comfy/isolation/clip_proxy.py @@ -0,0 +1,327 @@ +# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation +# CLIP Proxy implementation +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) + +if TYPE_CHECKING: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + +class CondStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "cond_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class CondStageModelProxy(BaseProxy[CondStageModelRegistry]): + _registry_class = CondStageModelRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class TokenizerRegistry(BaseRegistry[Any]): + _type_prefix = "tokenizer" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + +class TokenizerProxy(BaseProxy[TokenizerRegistry]): + _registry_class = TokenizerRegistry + __module__ = "comfy.sd" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +logger = logging.getLogger(__name__) + + +class CLIPRegistry(BaseRegistry[Any]): + _type_prefix = "clip" + _allowed_setters = { + "layer_idx", + "tokenizer_options", + "use_clip_schedule", + "apply_hooks_to_conds", + } + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def get_patcher_id(self, instance_id: str) -> str: + from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry + + return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher) + + async def get_cond_stage_model_id(self, instance_id: str) -> str: + return CondStageModelRegistry().register( + self._get_instance(instance_id).cond_stage_model + ) + + async def get_tokenizer_id(self, instance_id: str) -> str: + return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer) + + async def load_model(self, instance_id: str) -> None: + self._get_instance(instance_id).load_model() + + async def clip_layer(self, instance_id: str, layer_idx: int) -> None: + self._get_instance(instance_id).clip_layer(layer_idx) + + async def set_tokenizer_option( + self, instance_id: str, option_name: str, value: Any + ) -> None: + self._get_instance(instance_id).set_tokenizer_option(option_name, value) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def set_property(self, instance_id: str, name: str, value: Any) -> None: + if name not in self._allowed_setters: + raise PermissionError(f"Setting '{name}' is not allowed via RPC") + setattr(self._get_instance(instance_id), name, value) + + async def tokenize( + self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any + ) -> Any: + return self._get_instance(instance_id).tokenize( + text, return_word_ids=return_word_ids, **kwargs + ) + + async def encode(self, instance_id: str, text: str) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(text)) + + async def encode_from_tokens( + self, + instance_id: str, + tokens: Any, + return_pooled: bool = False, + return_dict: bool = False, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens( + tokens, return_pooled=return_pooled, return_dict=return_dict + ) + ) + + async def encode_from_tokens_scheduled( + self, + instance_id: str, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return detach_if_grad( + self._get_instance(instance_id).encode_from_tokens_scheduled( + tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar + ) + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch=strength_patch, strength_model=strength_model + ) + + async def get_key_patches(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_key_patches() + + async def load_sd( + self, instance_id: str, sd: dict, full_model: bool = False + ) -> Any: + return self._get_instance(instance_id).load_sd(sd, full_model=full_model) + + async def get_sd(self, instance_id: str) -> Any: + return self._get_instance(instance_id).get_sd() + + async def clone(self, instance_id: str) -> str: + return self.register(self._get_instance(instance_id).clone()) + + +class CLIPProxy(BaseProxy[CLIPRegistry]): + _registry_class = CLIPRegistry + __module__ = "comfy.sd" + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + @property + def patcher(self) -> "ModelPatcherProxy": + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @patcher.setter + def patcher(self, value: Any) -> None: + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + if isinstance(value, ModelPatcherProxy): + self._patcher_proxy = value + else: + logger.warning( + f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}" + ) + + @property + def cond_stage_model(self) -> CondStageModelProxy: + if not hasattr(self, "_cond_stage_model_proxy"): + csm_id = self._call_rpc("get_cond_stage_model_id") + self._cond_stage_model_proxy = CondStageModelProxy( + csm_id, manage_lifecycle=False + ) + return self._cond_stage_model_proxy + + @property + def tokenizer(self) -> TokenizerProxy: + if not hasattr(self, "_tokenizer_proxy"): + tok_id = self._call_rpc("get_tokenizer_id") + self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False) + return self._tokenizer_proxy + + def load_model(self) -> ModelPatcherProxy: + self._call_rpc("load_model") + return self.patcher + + @property + def layer_idx(self) -> Optional[int]: + return self._call_rpc("get_property", "layer_idx") + + @layer_idx.setter + def layer_idx(self, value: Optional[int]) -> None: + self._call_rpc("set_property", "layer_idx", value) + + @property + def tokenizer_options(self) -> dict: + return self._call_rpc("get_property", "tokenizer_options") + + @tokenizer_options.setter + def tokenizer_options(self, value: dict) -> None: + self._call_rpc("set_property", "tokenizer_options", value) + + @property + def use_clip_schedule(self) -> bool: + return self._call_rpc("get_property", "use_clip_schedule") + + @use_clip_schedule.setter + def use_clip_schedule(self, value: bool) -> None: + self._call_rpc("set_property", "use_clip_schedule", value) + + @property + def apply_hooks_to_conds(self) -> Any: + return self._call_rpc("get_property", "apply_hooks_to_conds") + + @apply_hooks_to_conds.setter + def apply_hooks_to_conds(self, value: Any) -> None: + self._call_rpc("set_property", "apply_hooks_to_conds", value) + + def clip_layer(self, layer_idx: int) -> None: + return self._call_rpc("clip_layer", layer_idx) + + def set_tokenizer_option(self, option_name: str, value: Any) -> None: + return self._call_rpc("set_tokenizer_option", option_name, value) + + def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any: + return self._call_rpc( + "tokenize", text, return_word_ids=return_word_ids, **kwargs + ) + + def encode(self, text: str) -> Any: + return self._call_rpc("encode", text) + + def encode_from_tokens( + self, tokens: Any, return_pooled: bool = False, return_dict: bool = False + ) -> Any: + res = self._call_rpc( + "encode_from_tokens", + tokens, + return_pooled=return_pooled, + return_dict=return_dict, + ) + if return_pooled and isinstance(res, list) and not return_dict: + return tuple(res) + return res + + def encode_from_tokens_scheduled( + self, + tokens: Any, + unprojected: bool = False, + add_dict: Optional[dict] = None, + show_pbar: bool = True, + ) -> Any: + add_dict = add_dict or {} + return self._call_rpc( + "encode_from_tokens_scheduled", + tokens, + unprojected=unprojected, + add_dict=add_dict, + show_pbar=show_pbar, + ) + + def add_patches( + self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0 + ) -> Any: + return self._call_rpc( + "add_patches", + patches, + strength_patch=strength_patch, + strength_model=strength_model, + ) + + def get_key_patches(self) -> Any: + return self._call_rpc("get_key_patches") + + def load_sd(self, sd: dict, full_model: bool = False) -> Any: + return self._call_rpc("load_sd", sd, full_model=full_model) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def clone(self) -> CLIPProxy: + new_id = self._call_rpc("clone") + return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS) + + +if not IS_CHILD_PROCESS: + _CLIP_REGISTRY_SINGLETON = CLIPRegistry() + _COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry() + _TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry() diff --git a/comfy/isolation/model_patcher_proxy.py b/comfy/isolation/model_patcher_proxy.py new file mode 100644 index 000000000..6769996b8 --- /dev/null +++ b/comfy/isolation/model_patcher_proxy.py @@ -0,0 +1,820 @@ +# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access +# RPC proxy for ModelPatcher (parent process) +from __future__ import annotations + +import logging +from typing import Any, Optional, List, Set, Dict, Callable + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, +) +from comfy.isolation.model_patcher_proxy_registry import ( + ModelPatcherRegistry, + AutoPatcherEjector, +) + +logger = logging.getLogger(__name__) + + +class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]): + _registry_class = ModelPatcherRegistry + __module__ = "comfy.model_patcher" + _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024 + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + self._registry_class, self._registry_class.get_remote_id() + ) + else: + self._rpc_caller = self._registry + return self._rpc_caller + + def get_all_callbacks(self, call_type: str = None) -> Any: + return self._call_rpc("get_all_callbacks", call_type) + + def get_all_wrappers(self, wrapper_type: str = None) -> Any: + return self._call_rpc("get_all_wrappers", wrapper_type) + + def _load_list(self, *args, **kwargs) -> Any: + return self._call_rpc("load_list_internal", *args, **kwargs) + + def prepare_hook_patches_current_keyframe( + self, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._call_rpc( + "prepare_hook_patches_current_keyframe", t, hook_group, model_options + ) + + def add_hook_patches( + self, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + self._call_rpc( + "add_hook_patches", hook, patches, strength_patch, strength_model + ) + + def clear_cached_hook_weights(self) -> None: + self._call_rpc("clear_cached_hook_weights") + + def get_combined_hook_patches(self, hooks: Any) -> Any: + return self._call_rpc("get_combined_hook_patches", hooks) + + def get_additional_models_with_key(self, key: str) -> Any: + return self._call_rpc("get_additional_models_with_key", key) + + @property + def object_patches(self) -> Any: + return self._call_rpc("get_object_patches") + + @property + def patches(self) -> Any: + res = self._call_rpc("get_patches") + if isinstance(res, dict): + new_res = {} + for k, v in res.items(): + new_list = [] + for item in v: + if isinstance(item, list): + new_list.append(tuple(item)) + else: + new_list.append(item) + new_res[k] = new_list + return new_res + return res + + @property + def pinned(self) -> Set: + val = self._call_rpc("get_patcher_attr", "pinned") + return set(val) if val is not None else set() + + @property + def hook_patches(self) -> Dict: + val = self._call_rpc("get_patcher_attr", "hook_patches") + if val is None: + return {} + try: + from comfy.hooks import _HookRef + import json + + new_val = {} + for k, v in val.items(): + if isinstance(k, str): + if k.startswith("PYISOLATE_HOOKREF:"): + ref_id = k.split(":", 1)[1] + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + elif k.startswith("__pyisolate_key__"): + try: + json_str = k[len("__pyisolate_key__") :] + data = json.loads(json_str) + ref_id = None + if isinstance(data, list): + for item in data: + if ( + isinstance(item, list) + and len(item) == 2 + and item[0] == "id" + ): + ref_id = item[1] + break + if ref_id: + h = _HookRef() + h._pyisolate_id = ref_id + new_val[h] = v + else: + new_val[k] = v + except Exception: + new_val[k] = v + else: + new_val[k] = v + else: + new_val[k] = v + return new_val + except ImportError: + return val + + def set_hook_mode(self, hook_mode: Any) -> None: + self._call_rpc("set_hook_mode", hook_mode) + + def register_all_hook_patches( + self, + hooks: Any, + target_dict: Any, + model_options: Any = None, + registered: Any = None, + ) -> None: + self._call_rpc( + "register_all_hook_patches", hooks, target_dict, model_options, registered + ) + + def is_clone(self, other: Any) -> bool: + if isinstance(other, ModelPatcherProxy): + return self._call_rpc("is_clone_by_id", other._instance_id) + return False + + def clone(self) -> ModelPatcherProxy: + new_id = self._call_rpc("clone") + return ModelPatcherProxy( + new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS + ) + + def clone_has_same_weights(self, clone: Any) -> bool: + if isinstance(clone, ModelPatcherProxy): + return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id) + if not IS_CHILD_PROCESS: + return self._call_rpc("is_clone", clone) + return False + + def get_model_object(self, name: str) -> Any: + return self._call_rpc("get_model_object", name) + + @property + def model_options(self) -> dict: + data = self._call_rpc("get_model_options") + import json + + def _decode_keys(obj): + if isinstance(obj, dict): + new_d = {} + for k, v in obj.items(): + if isinstance(k, str) and k.startswith("__pyisolate_key__"): + try: + json_str = k[17:] + val = json.loads(json_str) + if isinstance(val, list): + val = tuple(val) + new_d[val] = _decode_keys(v) + except: + new_d[k] = _decode_keys(v) + else: + new_d[k] = _decode_keys(v) + return new_d + if isinstance(obj, list): + return [_decode_keys(x) for x in obj] + return obj + + return _decode_keys(data) + + @model_options.setter + def model_options(self, value: dict) -> None: + self._call_rpc("set_model_options", value) + + def apply_hooks(self, hooks: Any) -> Any: + return self._call_rpc("apply_hooks", hooks) + + def prepare_state(self, timestep: Any) -> Any: + return self._call_rpc("prepare_state", timestep) + + def restore_hook_patches(self) -> None: + self._call_rpc("restore_hook_patches") + + def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None: + self._call_rpc("unpatch_hooks", whitelist_keys_set) + + def model_patches_to(self, device: Any) -> Any: + return self._call_rpc("model_patches_to", device) + + def partially_load( + self, device: Any, extra_memory: Any, force_patch_weights: bool = False + ) -> Any: + return self._call_rpc( + "partially_load", device, extra_memory, force_patch_weights + ) + + def partially_unload( + self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False + ) -> int: + return self._call_rpc( + "partially_unload", device_to, memory_to_free, force_patch_weights + ) + + def load( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._call_rpc( + "load", device_to, lowvram_model_memory, force_patch_weights, full_load + ) + + def patch_model( + self, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> Any: + self._call_rpc( + "patch_model", + device_to, + lowvram_model_memory, + load_weights, + force_patch_weights, + ) + return self + + def unpatch_model( + self, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._call_rpc("unpatch_model", device_to, unpatch_weights) + + def detach(self, unpatch_all: bool = True) -> Any: + self._call_rpc("detach", unpatch_all) + return self.model + + def _cpu_tensor_bytes(self, obj: Any) -> int: + import torch + + if isinstance(obj, torch.Tensor): + if obj.device.type == "cpu": + return obj.nbytes + return 0 + if isinstance(obj, dict): + return sum(self._cpu_tensor_bytes(v) for v in obj.values()) + if isinstance(obj, (list, tuple)): + return sum(self._cpu_tensor_bytes(v) for v in obj) + return 0 + + def _ensure_apply_model_headroom(self, required_bytes: int) -> bool: + if required_bytes <= 0: + return True + + import torch + import comfy.model_management as model_management + + target_raw = self.load_device + try: + if isinstance(target_raw, torch.device): + target = target_raw + elif isinstance(target_raw, str): + target = torch.device(target_raw) + elif isinstance(target_raw, int): + target = torch.device(f"cuda:{target_raw}") + else: + target = torch.device(target_raw) + except Exception: + return True + + if target.type != "cuda": + return True + + required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES + if model_management.get_free_memory(target) >= required: + return True + + model_management.cleanup_models_gc() + model_management.cleanup_models() + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.free_memory(required, target, for_dynamic=True) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + # Escalate to non-dynamic unloading before dispatching CUDA transfer. + model_management.free_memory(required, target, for_dynamic=False) + model_management.soft_empty_cache() + + if model_management.get_free_memory(target) < required: + model_management.load_models_gpu( + [self], + minimum_memory_required=required, + ) + + return model_management.get_free_memory(target) >= required + + def apply_model(self, *args, **kwargs) -> Any: + import torch + + required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) + self._ensure_apply_model_headroom(required_bytes) + + def _to_cuda(obj: Any) -> Any: + if isinstance(obj, torch.Tensor) and obj.device.type == "cpu": + return obj.to("cuda") + if isinstance(obj, dict): + return {k: _to_cuda(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_to_cuda(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_to_cuda(v) for v in obj) + return obj + + try: + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + except torch.OutOfMemoryError: + self._ensure_apply_model_headroom(required_bytes) + args_cuda = _to_cuda(args) + kwargs_cuda = _to_cuda(kwargs) + + return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) + + def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: + keys = self._call_rpc("model_state_dict", filter_prefix) + return dict.fromkeys(keys, None) + + def add_patches(self, *args: Any, **kwargs: Any) -> Any: + res = self._call_rpc("add_patches", *args, **kwargs) + if isinstance(res, list): + return [tuple(x) if isinstance(x, list) else x for x in res] + return res + + def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any: + return self._call_rpc("get_key_patches", filter_prefix) + + def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + self._call_rpc("patch_weight_to_device", key, device_to, inplace_update) + + def pin_weight_to_device(self, key): + self._call_rpc("pin_weight_to_device", key) + + def unpin_weight(self, key): + self._call_rpc("unpin_weight", key) + + def unpin_all_weights(self): + self._call_rpc("unpin_all_weights") + + def calculate_weight(self, patches, weight, key, intermediate_dtype=None): + return self._call_rpc( + "calculate_weight", patches, weight, key, intermediate_dtype + ) + + def inject_model(self) -> None: + self._call_rpc("inject_model") + + def eject_model(self) -> None: + self._call_rpc("eject_model") + + def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any: + return AutoPatcherEjector( + self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only + ) + + @property + def is_injected(self) -> bool: + return self._call_rpc("get_is_injected") + + @property + def skip_injection(self) -> bool: + return self._call_rpc("get_skip_injection") + + @skip_injection.setter + def skip_injection(self, value: bool) -> None: + self._call_rpc("set_skip_injection", value) + + def clean_hooks(self) -> None: + self._call_rpc("clean_hooks") + + def pre_run(self) -> None: + self._call_rpc("pre_run") + + def cleanup(self) -> None: + try: + self._call_rpc("cleanup") + except Exception: + logger.debug( + "ModelPatcherProxy cleanup RPC failed for %s", + self._instance_id, + exc_info=True, + ) + finally: + super().cleanup() + + @property + def model(self) -> _InnerModelProxy: + return _InnerModelProxy(self) + + def __getattr__(self, name: str) -> Any: + _whitelisted_attrs = { + "hook_patches_backup", + "hook_backup", + "cached_hook_patches", + "current_hooks", + "forced_hooks", + "is_clip", + "patches_uuid", + "pinned", + "attachments", + "additional_models", + "injections", + "hook_patches", + "model_lowvram", + "model_loaded_weight_memory", + "backup", + "object_patches_backup", + "weight_wrapper_patches", + "weight_inplace_update", + "force_cast_weights", + } + if name in _whitelisted_attrs: + return self._call_rpc("get_patcher_attr", name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def load_lora( + self, + lora_path: str, + strength_model: float, + clip: Optional[Any] = None, + strength_clip: float = 1.0, + ) -> tuple: + clip_id = None + if clip is not None: + clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None)) + result = self._call_rpc( + "load_lora", lora_path, strength_model, clip_id, strength_clip + ) + new_model = None + if result.get("model_id"): + new_model = ModelPatcherProxy( + result["model_id"], + self._registry, + manage_lifecycle=not IS_CHILD_PROCESS, + ) + new_clip = None + if result.get("clip_id"): + from comfy.isolation.clip_proxy import CLIPProxy + + new_clip = CLIPProxy(result["clip_id"]) + return (new_model, new_clip) + + @property + def load_device(self) -> Any: + return self._call_rpc("get_load_device") + + @property + def offload_device(self) -> Any: + return self._call_rpc("get_offload_device") + + @property + def device(self) -> Any: + return self.load_device + + def current_loaded_device(self) -> Any: + return self._call_rpc("current_loaded_device") + + @property + def size(self) -> int: + return self._call_rpc("get_size") + + def model_size(self) -> Any: + return self._call_rpc("model_size") + + def loaded_size(self) -> Any: + return self._call_rpc("loaded_size") + + def get_ram_usage(self) -> int: + return self._call_rpc("get_ram_usage") + + def lowvram_patch_counter(self) -> int: + return self._call_rpc("lowvram_patch_counter") + + def memory_required(self, input_shape: Any) -> Any: + return self._call_rpc("memory_required", input_shape) + + def is_dynamic(self) -> bool: + return bool(self._call_rpc("is_dynamic")) + + def get_free_memory(self, device: Any) -> Any: + return self._call_rpc("get_free_memory", device) + + def partially_unload_ram(self, ram_to_unload: int) -> Any: + return self._call_rpc("partially_unload_ram", ram_to_unload) + + def model_dtype(self) -> Any: + res = self._call_rpc("model_dtype") + if isinstance(res, str) and res.startswith("torch."): + try: + import torch + + attr = res.split(".")[-1] + if hasattr(torch, attr): + return getattr(torch, attr) + except ImportError: + pass + return res + + @property + def hook_mode(self) -> Any: + return self._call_rpc("get_hook_mode") + + @hook_mode.setter + def hook_mode(self, value: Any) -> None: + self._call_rpc("set_hook_mode", value) + + def set_model_sampler_cfg_function( + self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_cfg_function", + sampler_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_post_cfg_function( + self, post_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_post_cfg_function", + post_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_pre_cfg_function( + self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False + ) -> None: + self._call_rpc( + "set_model_sampler_pre_cfg_function", + pre_cfg_function, + disable_cfg1_optimization, + ) + + def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None: + self._call_rpc("set_model_sampler_calc_cond_batch_function", fn) + + def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None: + self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function) + + def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None: + self._call_rpc("set_model_denoise_mask_function", denoise_mask_function) + + def set_model_patch(self, patch: Any, name: str) -> None: + self._call_rpc("set_model_patch", patch, name) + + def set_model_patch_replace( + self, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._call_rpc( + "set_model_patch_replace", + patch, + name, + block_name, + number, + transformer_index, + ) + + def set_model_attn1_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn1", block_name, number, transformer_index + ) + + def set_model_attn2_replace( + self, + patch: Any, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self.set_model_patch_replace( + patch, "attn2", block_name, number, transformer_index + ) + + def set_model_attn1_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "attn2_output_patch") + + def set_model_input_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch") + + def set_model_input_block_patch_after_skip(self, patch: Any) -> None: + self.set_model_patch(patch, "input_block_patch_after_skip") + + def set_model_output_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "output_block_patch") + + def set_model_emb_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "emb_patch") + + def set_model_forward_timestep_embed_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "forward_timestep_embed_patch") + + def set_model_double_block_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "double_block") + + def set_model_post_input_patch(self, patch: Any) -> None: + self.set_model_patch(patch, "post_input") + + def set_model_rope_options( + self, + scale_x=1.0, + shift_x=0.0, + scale_y=1.0, + shift_y=0.0, + scale_t=1.0, + shift_t=0.0, + **kwargs: Any, + ) -> None: + options = { + "scale_x": scale_x, + "shift_x": shift_x, + "scale_y": scale_y, + "shift_y": shift_y, + "scale_t": scale_t, + "shift_t": shift_t, + } + options.update(kwargs) + self._call_rpc("set_model_rope_options", options) + + def set_model_compute_dtype(self, dtype: Any) -> None: + self._call_rpc("set_model_compute_dtype", dtype) + + def add_object_patch(self, name: str, obj: Any) -> None: + self._call_rpc("add_object_patch", name, obj) + + def add_weight_wrapper(self, name: str, function: Any) -> None: + self._call_rpc("add_weight_wrapper", name, function) + + def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None: + self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn) + + def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None: + self.add_wrapper_with_key(wrapper_type, None, wrapper) + + def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None: + self._call_rpc("remove_wrappers_with_key", wrapper_type, key) + + @property + def wrappers(self) -> Any: + return self._call_rpc("get_wrappers") + + def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None: + self._call_rpc("add_callback_with_key", call_type, key, callback) + + def add_callback(self, call_type: str, callback: Any) -> None: + self.add_callback_with_key(call_type, None, callback) + + def remove_callbacks_with_key(self, call_type: str, key: str) -> None: + self._call_rpc("remove_callbacks_with_key", call_type, key) + + @property + def callbacks(self) -> Any: + return self._call_rpc("get_callbacks") + + def set_attachments(self, key: str, attachment: Any) -> None: + self._call_rpc("set_attachments", key, attachment) + + def get_attachment(self, key: str) -> Any: + return self._call_rpc("get_attachment", key) + + def remove_attachments(self, key: str) -> None: + self._call_rpc("remove_attachments", key) + + def set_injections(self, key: str, injections: Any) -> None: + self._call_rpc("set_injections", key, injections) + + def get_injections(self, key: str) -> Any: + return self._call_rpc("get_injections", key) + + def remove_injections(self, key: str) -> None: + self._call_rpc("remove_injections", key) + + def set_additional_models(self, key: str, models: Any) -> None: + ids = [m._instance_id for m in models] + self._call_rpc("set_additional_models", key, ids) + + def remove_additional_models(self, key: str) -> None: + self._call_rpc("remove_additional_models", key) + + def get_nested_additional_models(self) -> Any: + return self._call_rpc("get_nested_additional_models") + + def get_additional_models(self) -> List[ModelPatcherProxy]: + ids = self._call_rpc("get_additional_models") + return [ + ModelPatcherProxy( + mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS + ) + for mid in ids + ] + + def model_patches_models(self) -> Any: + return self._call_rpc("model_patches_models") + + @property + def parent(self) -> Any: + return self._call_rpc("get_parent") + + +class _InnerModelProxy: + def __init__(self, parent: ModelPatcherProxy): + self._parent = parent + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + raise AttributeError(name) + if name in ( + "model_config", + "latent_format", + "model_type", + "current_weight_patches_uuid", + ): + return self._parent._call_rpc("get_inner_model_attr", name) + if name == "load_device": + return self._parent._call_rpc("get_inner_model_attr", "load_device") + if name == "device": + return self._parent._call_rpc("get_inner_model_attr", "device") + if name == "current_patcher": + return ModelPatcherProxy( + self._parent._instance_id, + self._parent._registry, + manage_lifecycle=False, + ) + if name == "model_sampling": + return self._parent._call_rpc("get_model_object", "model_sampling") + if name == "extra_conds_shapes": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds_shapes", a, k + ) + if name == "extra_conds": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_extra_conds", a, k + ) + if name == "memory_required": + return lambda *a, **k: self._parent._call_rpc( + "inner_model_memory_required", a, k + ) + if name == "apply_model": + # Delegate to parent's method to get the CPU->CUDA optimization + return self._parent.apply_model + if name == "process_latent_in": + return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k) + if name == "process_latent_out": + return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k) + if name == "scale_latent_inpaint": + return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k) + if name == "diffusion_model": + return self._parent._call_rpc("get_inner_model_attr", "diffusion_model") + raise AttributeError(f"'{name}' not supported on isolated InnerModel") diff --git a/comfy/isolation/model_patcher_proxy_registry.py b/comfy/isolation/model_patcher_proxy_registry.py new file mode 100644 index 000000000..7224c3233 --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_registry.py @@ -0,0 +1,875 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import +# RPC server for ModelPatcher isolation (child process) +from __future__ import annotations + +import gc +import logging +from typing import Any, Optional, List + +try: + from comfy.model_patcher import AutoPatcherEjector +except ImportError: + + class AutoPatcherEjector: + def __init__(self, model, skip_and_inject_on_exit_only=False): + self.model = model + self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only + self.prev_skip_injection = False + self.was_injected = False + + def __enter__(self): + self.was_injected = False + self.prev_skip_injection = self.model.skip_injection + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = True + if self.model.is_injected: + self.model.eject_model() + self.was_injected = True + + def __exit__(self, *args): + if self.skip_and_inject_on_exit_only: + self.model.skip_injection = self.prev_skip_injection + self.model.inject_model() + if self.was_injected and not self.model.skip_injection: + self.model.inject_model() + self.model.skip_injection = self.prev_skip_injection + + +from comfy.isolation.proxies.base import ( + BaseRegistry, + detach_if_grad, +) + +logger = logging.getLogger(__name__) + + +class ModelPatcherRegistry(BaseRegistry[Any]): + _type_prefix = "model" + + def __init__(self) -> None: + super().__init__() + self._pending_cleanup_ids: set[str] = set() + + async def clone(self, instance_id: str) -> str: + instance = self._get_instance(instance_id) + new_model = instance.clone() + return self.register(new_model) + + async def is_clone(self, instance_id: str, other: Any) -> bool: + instance = self._get_instance(instance_id) + if hasattr(other, "model"): + return instance.is_clone(other) + return False + + async def get_model_object(self, instance_id: str, name: str) -> Any: + instance = self._get_instance(instance_id) + if name == "model": + return f"" + result = instance.get_model_object(name) + if name == "model_sampling": + from comfy.isolation.model_sampling_proxy import ( + ModelSamplingRegistry, + ModelSamplingProxy, + ) + + registry = ModelSamplingRegistry() + sampling_id = registry.register(result) + return ModelSamplingProxy(sampling_id, registry) + return detach_if_grad(result) + + async def get_model_options(self, instance_id: str) -> dict: + instance = self._get_instance(instance_id) + import copy + + opts = copy.deepcopy(instance.model_options) + return self._sanitize_rpc_result(opts) + + async def set_model_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).model_options = options + + async def get_patcher_attr(self, instance_id: str, name: str) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), name, None) + ) + + async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any: + instance = self._get_instance(instance_id) + sd_keys = instance.model.state_dict().keys() + return dict.fromkeys(sd_keys, None) + + def _sanitize_rpc_result(self, obj, seen=None): + if seen is None: + seen = set() + if obj is None: + return None + if isinstance(obj, (bool, int, float, str)): + if isinstance(obj, str) and len(obj) > 500000: + return f"" + return obj + obj_id = id(obj) + if obj_id in seen: + return None + seen.add(obj_id) + if isinstance(obj, (list, tuple)): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, set): + return [self._sanitize_rpc_result(x, seen) for x in obj] + if isinstance(obj, dict): + new_dict = {} + for k, v in obj.items(): + if isinstance(k, tuple): + import json + + try: + key_str = "__pyisolate_key__" + json.dumps(list(k)) + new_dict[key_str] = self._sanitize_rpc_result(v, seen) + except Exception: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + else: + new_dict[str(k)] = self._sanitize_rpc_result(v, seen) + return new_dict + if ( + hasattr(obj, "__dict__") + and not hasattr(obj, "__get__") + and not hasattr(obj, "__call__") + ): + return self._sanitize_rpc_result(obj.__dict__, seen) + if hasattr(obj, "items") and hasattr(obj, "get"): + return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()} + return None + + async def get_load_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).load_device + + async def get_offload_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).offload_device + + async def current_loaded_device(self, instance_id: str) -> Any: + return self._get_instance(instance_id).current_loaded_device() + + async def get_size(self, instance_id: str) -> int: + return self._get_instance(instance_id).size + + async def model_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).model_size() + + async def loaded_size(self, instance_id: str) -> Any: + return self._get_instance(instance_id).loaded_size() + + async def get_ram_usage(self, instance_id: str) -> int: + return self._get_instance(instance_id).get_ram_usage() + + async def lowvram_patch_counter(self, instance_id: str) -> int: + return self._get_instance(instance_id).lowvram_patch_counter() + + async def memory_required(self, instance_id: str, input_shape: Any) -> Any: + return self._get_instance(instance_id).memory_required(input_shape) + + async def is_dynamic(self, instance_id: str) -> bool: + instance = self._get_instance(instance_id) + if hasattr(instance, "is_dynamic"): + return bool(instance.is_dynamic()) + return False + + async def get_free_memory(self, instance_id: str, device: Any) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "get_free_memory"): + return instance.get_free_memory(device) + import comfy.model_management + + return comfy.model_management.get_free_memory(device) + + async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any: + instance = self._get_instance(instance_id) + if hasattr(instance, "partially_unload_ram"): + return instance.partially_unload_ram(ram_to_unload) + return None + + async def model_dtype(self, instance_id: str) -> Any: + return self._get_instance(instance_id).model_dtype() + + async def model_patches_to(self, instance_id: str, device: Any) -> Any: + return self._get_instance(instance_id).model_patches_to(device) + + async def partially_load( + self, + instance_id: str, + device: Any, + extra_memory: Any, + force_patch_weights: bool = False, + ) -> Any: + return self._get_instance(instance_id).partially_load( + device, extra_memory, force_patch_weights=force_patch_weights + ) + + async def partially_unload( + self, + instance_id: str, + device_to: Any, + memory_to_free: int = 0, + force_patch_weights: bool = False, + ) -> int: + return self._get_instance(instance_id).partially_unload( + device_to, memory_to_free, force_patch_weights + ) + + async def load( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + force_patch_weights: bool = False, + full_load: bool = False, + ) -> None: + self._get_instance(instance_id).load( + device_to, lowvram_model_memory, force_patch_weights, full_load + ) + + async def patch_model( + self, + instance_id: str, + device_to: Any = None, + lowvram_model_memory: int = 0, + load_weights: bool = True, + force_patch_weights: bool = False, + ) -> None: + try: + self._get_instance(instance_id).patch_model( + device_to, lowvram_model_memory, load_weights, force_patch_weights + ) + except AttributeError as e: + logger.error( + f"Isolation Error: Failed to patch model attribute: {e}. Skipping." + ) + return + + async def unpatch_model( + self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True + ) -> None: + self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights) + + async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: + self._get_instance(instance_id).detach(unpatch_all) + + async def prepare_state(self, instance_id: str, timestep: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.prepare_state(timestep) + + async def pre_run(self, instance_id: str) -> None: + self._get_instance(instance_id).pre_run() + + async def cleanup(self, instance_id: str) -> None: + try: + instance = self._get_instance(instance_id) + except Exception: + logger.debug( + "ModelPatcher cleanup requested for missing instance %s", + instance_id, + exc_info=True, + ) + return + + try: + instance.cleanup() + finally: + with self._lock: + self._pending_cleanup_ids.add(instance_id) + gc.collect() + + def sweep_pending_cleanup(self) -> int: + removed = 0 + with self._lock: + pending_ids = list(self._pending_cleanup_ids) + self._pending_cleanup_ids.clear() + for instance_id in pending_ids: + instance = self._registry.pop(instance_id, None) + if instance is None: + continue + self._id_map.pop(id(instance), None) + removed += 1 + + gc.collect() + return removed + + def purge_all(self) -> int: + with self._lock: + removed = len(self._registry) + self._registry.clear() + self._id_map.clear() + self._pending_cleanup_ids.clear() + gc.collect() + return removed + + async def apply_hooks(self, instance_id: str, hooks: Any) -> Any: + instance = self._get_instance(instance_id) + cp = getattr(instance.model, "current_patcher", instance) + if cp is None: + cp = instance + return cp.apply_hooks(hooks=hooks) + + async def clean_hooks(self, instance_id: str) -> None: + self._get_instance(instance_id).clean_hooks() + + async def restore_hook_patches(self, instance_id: str) -> None: + self._get_instance(instance_id).restore_hook_patches() + + async def unpatch_hooks( + self, instance_id: str, whitelist_keys_set: Optional[set] = None + ) -> None: + self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set) + + async def register_all_hook_patches( + self, + instance_id: str, + hooks: Any, + target_dict: Any, + model_options: Any, + registered: Any, + ) -> None: + from types import SimpleNamespace + import comfy.hooks + + instance = self._get_instance(instance_id) + if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"): + hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks + new_hooks = comfy.hooks.HookGroup() + if hasattr(hook_data, "hooks"): + new_hooks.hooks = ( + hook_data["hooks"] + if isinstance(hook_data, dict) + else hook_data.hooks + ) + hooks = new_hooks + instance.register_all_hook_patches( + hooks, target_dict, model_options, registered + ) + + async def get_hook_mode(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "hook_mode", None) + + async def set_hook_mode(self, instance_id: str, value: Any) -> None: + setattr(self._get_instance(instance_id), "hook_mode", value) + + async def inject_model(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + try: + instance.inject_model() + except AttributeError as e: + if "inject" in str(e): + logger.error( + "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping." + ) + return + raise e + + async def eject_model(self, instance_id: str) -> None: + self._get_instance(instance_id).eject_model() + + async def get_is_injected(self, instance_id: str) -> bool: + return self._get_instance(instance_id).is_injected + + async def set_skip_injection(self, instance_id: str, value: bool) -> None: + self._get_instance(instance_id).skip_injection = value + + async def get_skip_injection(self, instance_id: str) -> bool: + return self._get_instance(instance_id).skip_injection + + async def set_model_sampler_cfg_function( + self, + instance_id: str, + sampler_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + if not callable(sampler_cfg_function): + logger.error( + f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping." + ) + return + self._get_instance(instance_id).set_model_sampler_cfg_function( + sampler_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_post_cfg_function( + self, + instance_id: str, + post_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_post_cfg_function( + post_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_pre_cfg_function( + self, + instance_id: str, + pre_cfg_function: Any, + disable_cfg1_optimization: bool = False, + ) -> None: + self._get_instance(instance_id).set_model_sampler_pre_cfg_function( + pre_cfg_function, disable_cfg1_optimization + ) + + async def set_model_sampler_calc_cond_batch_function( + self, instance_id: str, fn: Any + ) -> None: + self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn) + + async def set_model_unet_function_wrapper( + self, instance_id: str, unet_wrapper_function: Any + ) -> None: + self._get_instance(instance_id).set_model_unet_function_wrapper( + unet_wrapper_function + ) + + async def set_model_denoise_mask_function( + self, instance_id: str, denoise_mask_function: Any + ) -> None: + self._get_instance(instance_id).set_model_denoise_mask_function( + denoise_mask_function + ) + + async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None: + self._get_instance(instance_id).set_model_patch(patch, name) + + async def set_model_patch_replace( + self, + instance_id: str, + patch: Any, + name: str, + block_name: str, + number: int, + transformer_index: Optional[int] = None, + ) -> None: + self._get_instance(instance_id).set_model_patch_replace( + patch, name, block_name, number, transformer_index + ) + + async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_input_block_patch(patch) + + async def set_model_input_block_patch_after_skip( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch) + + async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_output_block_patch(patch) + + async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_emb_patch(patch) + + async def set_model_forward_timestep_embed_patch( + self, instance_id: str, patch: Any + ) -> None: + self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch) + + async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_double_block_patch(patch) + + async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None: + self._get_instance(instance_id).set_model_post_input_patch(patch) + + async def set_model_rope_options(self, instance_id: str, options: dict) -> None: + self._get_instance(instance_id).set_model_rope_options(**options) + + async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None: + self._get_instance(instance_id).set_model_compute_dtype(dtype) + + async def clone_has_same_weights_by_id( + self, instance_id: str, other_id: str + ) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if not other: + return False + return instance.clone_has_same_weights(other) + + async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any: + return self._get_instance(instance_id)._load_list(*args, **kwargs) + + async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool: + instance = self._get_instance(instance_id) + other = self._get_instance(other_id) + if hasattr(instance, "is_clone"): + return instance.is_clone(other) + return False + + async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None: + self._get_instance(instance_id).add_object_patch(name, obj) + + async def add_weight_wrapper( + self, instance_id: str, name: str, function: Any + ) -> None: + self._get_instance(instance_id).add_weight_wrapper(name, function) + + async def add_wrapper_with_key( + self, instance_id: str, wrapper_type: Any, key: str, fn: Any + ) -> None: + self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn) + + async def remove_wrappers_with_key( + self, instance_id: str, wrapper_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key) + + async def get_wrappers( + self, instance_id: str, wrapper_type: str = None, key: str = None + ) -> Any: + if wrapper_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "wrappers", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_wrappers(wrapper_type, key) + ) + + async def get_all_wrappers(self, instance_id: str, wrapper_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])( + wrapper_type + ) + ) + + async def add_callback_with_key( + self, instance_id: str, call_type: str, key: str, callback: Any + ) -> None: + self._get_instance(instance_id).add_callback_with_key(call_type, key, callback) + + async def remove_callbacks_with_key( + self, instance_id: str, call_type: str, key: str + ) -> None: + self._get_instance(instance_id).remove_callbacks_with_key(call_type, key) + + async def get_callbacks( + self, instance_id: str, call_type: str = None, key: str = None + ) -> Any: + if call_type is None and key is None: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "callbacks", {}) + ) + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_callbacks(call_type, key) + ) + + async def get_all_callbacks(self, instance_id: str, call_type: str = None) -> Any: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])( + call_type + ) + ) + + async def set_attachments( + self, instance_id: str, key: str, attachment: Any + ) -> None: + self._get_instance(instance_id).set_attachments(key, attachment) + + async def get_attachment(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_attachment(key) + ) + + async def remove_attachments(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_attachments(key) + + async def set_injections(self, instance_id: str, key: str, injections: Any) -> None: + self._get_instance(instance_id).set_injections(key, injections) + + async def get_injections(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_injections(key) + ) + + async def remove_injections(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_injections(key) + + async def set_additional_models( + self, instance_id: str, key: str, models: Any + ) -> None: + self._get_instance(instance_id).set_additional_models(key, models) + + async def remove_additional_models(self, instance_id: str, key: str) -> None: + self._get_instance(instance_id).remove_additional_models(key) + + async def get_nested_additional_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_nested_additional_models() + ) + + async def get_additional_models(self, instance_id: str) -> List[str]: + models = self._get_instance(instance_id).get_additional_models() + return [self.register(m) for m in models] + + async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).get_additional_models_with_key(key) + ) + + async def model_patches_models(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).model_patches_models() + ) + + async def get_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy()) + + async def get_object_patches(self, instance_id: str) -> Any: + return self._sanitize_rpc_result( + self._get_instance(instance_id).object_patches.copy() + ) + + async def add_patches( + self, + instance_id: str, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> Any: + return self._get_instance(instance_id).add_patches( + patches, strength_patch, strength_model + ) + + async def get_key_patches( + self, instance_id: str, filter_prefix: Optional[str] = None + ) -> Any: + res = self._get_instance(instance_id).get_key_patches() + if filter_prefix: + res = {k: v for k, v in res.items() if k.startswith(filter_prefix)} + safe_res = {} + for k, v in res.items(): + safe_res[k] = [ + f"" + if hasattr(t, "shape") + else str(t) + for t in v + ] + return safe_res + + async def add_hook_patches( + self, + instance_id: str, + hook: Any, + patches: Any, + strength_patch: float = 1.0, + strength_model: float = 1.0, + ) -> None: + if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook.hook_ref.items())) + except Exception: + hook.hook_ref = None + self._get_instance(instance_id).add_hook_patches( + hook, patches, strength_patch, strength_model + ) + + async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any: + if hooks is not None and hasattr(hooks, "hooks"): + for hook in getattr(hooks, "hooks", []): + hook_ref = getattr(hook, "hook_ref", None) + if isinstance(hook_ref, dict): + try: + hook.hook_ref = tuple(sorted(hook_ref.items())) + except Exception: + hook.hook_ref = None + res = self._get_instance(instance_id).get_combined_hook_patches(hooks) + return self._sanitize_rpc_result(res) + + async def clear_cached_hook_weights(self, instance_id: str) -> None: + self._get_instance(instance_id).clear_cached_hook_weights() + + async def prepare_hook_patches_current_keyframe( + self, instance_id: str, t: Any, hook_group: Any, model_options: Any + ) -> None: + self._get_instance(instance_id).prepare_hook_patches_current_keyframe( + t, hook_group, model_options + ) + + async def get_parent(self, instance_id: str) -> Any: + return getattr(self._get_instance(instance_id), "parent", None) + + async def patch_weight_to_device( + self, + instance_id: str, + key: str, + device_to: Any = None, + inplace_update: bool = False, + ) -> None: + self._get_instance(instance_id).patch_weight_to_device( + key, device_to, inplace_update + ) + + async def pin_weight_to_device(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.pin_weight_to_device(key) + + async def unpin_weight(self, instance_id: str, key: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_weight(key) + + async def unpin_all_weights(self, instance_id: str) -> None: + instance = self._get_instance(instance_id) + if hasattr(instance, "pinned") and isinstance(instance.pinned, list): + instance.pinned = set(instance.pinned) + instance.unpin_all_weights() + + async def calculate_weight( + self, + instance_id: str, + patches: Any, + weight: Any, + key: str, + intermediate_dtype: Any = float, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).calculate_weight( + patches, weight, key, intermediate_dtype + ) + ) + + async def get_inner_model_attr(self, instance_id: str, name: str) -> Any: + try: + return self._sanitize_rpc_result( + getattr(self._get_instance(instance_id).model, name) + ) + except AttributeError: + return None + + async def inner_model_memory_required( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._get_instance(instance_id).model.memory_required(*args, **kwargs) + + async def inner_model_extra_conds_shapes( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs) + + async def inner_model_extra_conds( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return self._get_instance(instance_id).model.extra_conds(*args, **kwargs) + + async def inner_model_state_dict( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs) + return { + k: {"numel": v.numel(), "element_size": v.element_size()} + for k, v in sd.items() + } + + async def inner_model_apply_model( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + instance = self._get_instance(instance_id) + target = getattr(instance, "load_device", None) + if target is None and args and hasattr(args[0], "device"): + target = args[0].device + elif target is None: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + + def _move(obj): + if target is None: + return obj + if isinstance(obj, (tuple, list)): + return type(obj)(_move(o) for o in obj) + if hasattr(obj, "to"): + return obj.to(target) + return obj + + moved_args = tuple(_move(a) for a in args) + moved_kwargs = {k: _move(v) for k, v in kwargs.items()} + result = instance.model.apply_model(*moved_args, **moved_kwargs) + return detach_if_grad(_move(result)) + + async def process_latent_in( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).model.process_latent_in(*args, **kwargs) + ) + + async def process_latent_out( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + instance = self._get_instance(instance_id) + result = instance.model.process_latent_out(*args, **kwargs) + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + return detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "process_latent_out: failed to move result to target device", + exc_info=True, + ) + return detach_if_grad(result) + + async def scale_latent_inpaint( + self, instance_id: str, args: tuple, kwargs: dict + ) -> Any: + instance = self._get_instance(instance_id) + result = instance.model.scale_latent_inpaint(*args, **kwargs) + try: + target = None + if args and hasattr(args[0], "device"): + target = args[0].device + elif kwargs: + for v in kwargs.values(): + if hasattr(v, "device"): + target = v.device + break + if target is not None and hasattr(result, "to"): + return detach_if_grad(result.to(target)) + except Exception: + logger.debug( + "scale_latent_inpaint: failed to move result to target device", + exc_info=True, + ) + return detach_if_grad(result) + + async def load_lora( + self, + instance_id: str, + lora_path: str, + strength_model: float, + clip_id: Optional[str] = None, + strength_clip: float = 1.0, + ) -> dict: + import comfy.utils + import comfy.sd + import folder_paths + from comfy.isolation.clip_proxy import CLIPRegistry + + model = self._get_instance(instance_id) + clip = None + if clip_id: + clip = CLIPRegistry()._get_instance(clip_id) + lora_full_path = folder_paths.get_full_path("loras", lora_path) + if lora_full_path is None: + raise ValueError(f"LoRA file not found: {lora_path}") + lora = comfy.utils.load_torch_file(lora_full_path) + new_model, new_clip = comfy.sd.load_lora_for_models( + model, clip, lora, strength_model, strength_clip + ) + new_model_id = self.register(new_model) if new_model else None + new_clip_id = ( + CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None + ) + return {"model_id": new_model_id, "clip_id": new_clip_id} diff --git a/comfy/isolation/model_patcher_proxy_utils.py b/comfy/isolation/model_patcher_proxy_utils.py new file mode 100644 index 000000000..b22e3464f --- /dev/null +++ b/comfy/isolation/model_patcher_proxy_utils.py @@ -0,0 +1,154 @@ +# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access +# Isolation utilities and serializers for ModelPatcherProxy +from __future__ import annotations + +import logging +import os +from typing import Any + +logger = logging.getLogger(__name__) + + +def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any: + from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry + from comfy.isolation.model_patcher_proxy import ModelPatcherProxy + + isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" + is_child = os.environ.get("PYISOLATE_CHILD") == "1" + + if not isolation_active: + return model_patcher + if is_child: + return model_patcher + if isinstance(model_patcher, ModelPatcherProxy): + return model_patcher + + registry = ModelPatcherRegistry() + model_id = registry.register(model_patcher) + logger.debug(f"Isolated ModelPatcher: {model_id}") + return ModelPatcherProxy(model_id, registry, manage_lifecycle=True) + + +def register_hooks_serializers(registry=None): + from pyisolate._internal.serialization_registry import SerializerRegistry + import comfy.hooks + + if registry is None: + registry = SerializerRegistry.get_instance() + + def serialize_enum(obj): + return {"__enum__": f"{type(obj).__name__}.{obj.name}"} + + def deserialize_enum(data): + cls_name, val_name = data["__enum__"].split(".") + cls = getattr(comfy.hooks, cls_name) + return cls[val_name] + + registry.register("EnumHookType", serialize_enum, deserialize_enum) + registry.register("EnumHookScope", serialize_enum, deserialize_enum) + registry.register("EnumHookMode", serialize_enum, deserialize_enum) + registry.register("EnumWeightTarget", serialize_enum, deserialize_enum) + + def serialize_hook_group(obj): + return {"__type__": "HookGroup", "hooks": obj.hooks} + + def deserialize_hook_group(data): + hg = comfy.hooks.HookGroup() + for h in data["hooks"]: + hg.add(h) + return hg + + registry.register("HookGroup", serialize_hook_group, deserialize_hook_group) + + def serialize_dict_state(obj): + d = obj.__dict__.copy() + d["__type__"] = type(obj).__name__ + if "custom_should_register" in d: + del d["custom_should_register"] + return d + + def deserialize_dict_state_generic(cls): + def _deserialize(data): + h = cls() + h.__dict__.update(data) + return h + + return _deserialize + + def deserialize_hook_keyframe(data): + h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0)) + h.__dict__.update(data) + return h + + registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe) + + def deserialize_hook_keyframe_group(data): + h = comfy.hooks.HookKeyframeGroup() + h.__dict__.update(data) + return h + + registry.register( + "HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group + ) + + def deserialize_hook(data): + h = comfy.hooks.Hook() + h.__dict__.update(data) + return h + + registry.register("Hook", serialize_dict_state, deserialize_hook) + + def deserialize_weight_hook(data): + h = comfy.hooks.WeightHook() + h.__dict__.update(data) + return h + + registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook) + + def serialize_set(obj): + return {"__set__": list(obj)} + + def deserialize_set(data): + return set(data["__set__"]) + + registry.register("set", serialize_set, deserialize_set) + + try: + from comfy.weight_adapter.lora import LoRAAdapter + + def serialize_lora(obj): + return {"weights": {}, "loaded_keys": list(obj.loaded_keys)} + + def deserialize_lora(data): + return LoRAAdapter(set(data["loaded_keys"]), data["weights"]) + + registry.register("LoRAAdapter", serialize_lora, deserialize_lora) + except Exception: + pass + + try: + from comfy.hooks import _HookRef + import uuid + + def serialize_hook_ref(obj): + return { + "__hook_ref__": True, + "id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())), + } + + def deserialize_hook_ref(data): + h = _HookRef() + h._pyisolate_id = data.get("id", str(uuid.uuid4())) + return h + + registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref) + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to register _HookRef: {e}") + + +try: + register_hooks_serializers() +except Exception as e: + logger.error(f"Failed to initialize hook serializers: {e}") diff --git a/comfy/isolation/model_sampling_proxy.py b/comfy/isolation/model_sampling_proxy.py new file mode 100644 index 000000000..886c60409 --- /dev/null +++ b/comfy/isolation/model_sampling_proxy.py @@ -0,0 +1,253 @@ +# pylint: disable=import-outside-toplevel +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from comfy.isolation.proxies.base import ( + BaseProxy, + BaseRegistry, + detach_if_grad, + get_thread_loop, + run_coro_in_new_loop, +) + +logger = logging.getLogger(__name__) + + +def _prefer_device(*tensors: Any) -> Any: + try: + import torch + except Exception: + return None + for t in tensors: + if isinstance(t, torch.Tensor) and t.is_cuda: + return t.device + for t in tensors: + if isinstance(t, torch.Tensor): + return t.device + return None + + +def _to_device(obj: Any, device: Any) -> Any: + try: + import torch + except Exception: + return obj + if device is None: + return obj + if isinstance(obj, torch.Tensor): + if obj.device != device: + return obj.to(device) + return obj + if isinstance(obj, (list, tuple)): + converted = [_to_device(x, device) for x in obj] + return type(obj)(converted) if isinstance(obj, tuple) else converted + if isinstance(obj, dict): + return {k: _to_device(v, device) for k, v in obj.items()} + return obj + + +class ModelSamplingRegistry(BaseRegistry[Any]): + _type_prefix = "modelsampling" + + async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.calculate_input(sigma, noise)) + + async def calculate_denoised( + self, instance_id: str, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.calculate_denoised(sigma, model_output, model_input) + ) + + async def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad( + sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise) + ) + + async def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent)) + + async def timestep(self, instance_id: str, sigma: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.timestep(sigma) + + async def sigma(self, instance_id: str, timestep: Any) -> Any: + sampling = self._get_instance(instance_id) + return sampling.sigma(timestep) + + async def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + sampling = self._get_instance(instance_id) + return sampling.percent_to_sigma(percent) + + async def get_sigma_min(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_min) + + async def get_sigma_max(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_max) + + async def get_sigma_data(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigma_data) + + async def get_sigmas(self, instance_id: str) -> Any: + sampling = self._get_instance(instance_id) + return detach_if_grad(sampling.sigmas) + + async def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + sampling = self._get_instance(instance_id) + sampling.set_sigmas(sigmas) + + +class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]): + _registry_class = ModelSamplingRegistry + __module__ = "comfy.isolation.model_sampling_proxy" + + def _get_rpc(self) -> Any: + if self._rpc_caller is None: + from pyisolate._internal.rpc_protocol import get_child_rpc_instance + + rpc = get_child_rpc_instance() + if rpc is not None: + self._rpc_caller = rpc.create_caller( + ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id() + ) + else: + registry = ModelSamplingRegistry() + + class _LocalCaller: + def calculate_input( + self, instance_id: str, sigma: Any, noise: Any + ) -> Any: + return registry.calculate_input(instance_id, sigma, noise) + + def calculate_denoised( + self, + instance_id: str, + sigma: Any, + model_output: Any, + model_input: Any, + ) -> Any: + return registry.calculate_denoised( + instance_id, sigma, model_output, model_input + ) + + def noise_scaling( + self, + instance_id: str, + sigma: Any, + noise: Any, + latent_image: Any, + max_denoise: bool = False, + ) -> Any: + return registry.noise_scaling( + instance_id, sigma, noise, latent_image, max_denoise + ) + + def inverse_noise_scaling( + self, instance_id: str, sigma: Any, latent: Any + ) -> Any: + return registry.inverse_noise_scaling( + instance_id, sigma, latent + ) + + def timestep(self, instance_id: str, sigma: Any) -> Any: + return registry.timestep(instance_id, sigma) + + def sigma(self, instance_id: str, timestep: Any) -> Any: + return registry.sigma(instance_id, timestep) + + def percent_to_sigma(self, instance_id: str, percent: float) -> Any: + return registry.percent_to_sigma(instance_id, percent) + + def get_sigma_min(self, instance_id: str) -> Any: + return registry.get_sigma_min(instance_id) + + def get_sigma_max(self, instance_id: str) -> Any: + return registry.get_sigma_max(instance_id) + + def get_sigma_data(self, instance_id: str) -> Any: + return registry.get_sigma_data(instance_id) + + def get_sigmas(self, instance_id: str) -> Any: + return registry.get_sigmas(instance_id) + + def set_sigmas(self, instance_id: str, sigmas: Any) -> None: + return registry.set_sigmas(instance_id, sigmas) + + self._rpc_caller = _LocalCaller() + return self._rpc_caller + + def _call(self, method_name: str, *args: Any) -> Any: + rpc = self._get_rpc() + method = getattr(rpc, method_name) + result = method(self._instance_id, *args) + if asyncio.iscoroutine(result): + try: + asyncio.get_running_loop() + return run_coro_in_new_loop(result) + except RuntimeError: + loop = get_thread_loop() + return loop.run_until_complete(result) + return result + + @property + def sigma_min(self) -> Any: + return self._call("get_sigma_min") + + @property + def sigma_max(self) -> Any: + return self._call("get_sigma_max") + + @property + def sigma_data(self) -> Any: + return self._call("get_sigma_data") + + @property + def sigmas(self) -> Any: + return self._call("get_sigmas") + + def calculate_input(self, sigma: Any, noise: Any) -> Any: + return self._call("calculate_input", sigma, noise) + + def calculate_denoised( + self, sigma: Any, model_output: Any, model_input: Any + ) -> Any: + return self._call("calculate_denoised", sigma, model_output, model_input) + + def noise_scaling( + self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False + ) -> Any: + return self._call("noise_scaling", sigma, noise, latent_image, max_denoise) + + def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: + return self._call("inverse_noise_scaling", sigma, latent) + + def timestep(self, sigma: Any) -> Any: + return self._call("timestep", sigma) + + def sigma(self, timestep: Any) -> Any: + return self._call("sigma", timestep) + + def percent_to_sigma(self, percent: float) -> Any: + return self._call("percent_to_sigma", percent) + + def set_sigmas(self, sigmas: Any) -> None: + return self._call("set_sigmas", sigmas) diff --git a/comfy/isolation/vae_proxy.py b/comfy/isolation/vae_proxy.py new file mode 100644 index 000000000..8260d06a3 --- /dev/null +++ b/comfy/isolation/vae_proxy.py @@ -0,0 +1,214 @@ +# pylint: disable=attribute-defined-outside-init +import logging +from typing import Any + +from comfy.isolation.proxies.base import ( + IS_CHILD_PROCESS, + BaseProxy, + BaseRegistry, + detach_if_grad, +) +from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry + +logger = logging.getLogger(__name__) + + +class FirstStageModelRegistry(BaseRegistry[Any]): + _type_prefix = "first_stage_model" + + async def get_property(self, instance_id: str, name: str) -> Any: + obj = self._get_instance(instance_id) + return getattr(obj, name) + + async def has_property(self, instance_id: str, name: str) -> bool: + obj = self._get_instance(instance_id) + return hasattr(obj, name) + + +class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]): + _registry_class = FirstStageModelRegistry + __module__ = "comfy.ldm.models.autoencoder" + + def __getattr__(self, name: str) -> Any: + try: + return self._call_rpc("get_property", name) + except Exception as e: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) from e + + def __repr__(self) -> str: + return f"" + + +class VAERegistry(BaseRegistry[Any]): + _type_prefix = "vae" + + async def get_patcher_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return ModelPatcherRegistry().register(vae.patcher) + + async def get_first_stage_model_id(self, instance_id: str) -> str: + vae = self._get_instance(instance_id) + return FirstStageModelRegistry().register(vae.first_stage_model) + + async def encode(self, instance_id: str, pixels: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).encode(pixels)) + + async def encode_tiled( + self, + instance_id: str, + pixels: Any, + tile_x: int = 512, + tile_y: int = 512, + overlap: int = 64, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).encode_tiled( + pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap + ) + ) + + async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs)) + + async def decode_tiled( + self, + instance_id: str, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return detach_if_grad( + self._get_instance(instance_id).decode_tiled( + samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs + ) + ) + + async def get_property(self, instance_id: str, name: str) -> Any: + return getattr(self._get_instance(instance_id), name) + + async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_encode(shape, dtype) + + async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int: + return self._get_instance(instance_id).memory_used_decode(shape, dtype) + + async def process_input(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_input(image)) + + async def process_output(self, instance_id: str, image: Any) -> Any: + return detach_if_grad(self._get_instance(instance_id).process_output(image)) + + +class VAEProxy(BaseProxy[VAERegistry]): + _registry_class = VAERegistry + __module__ = "comfy.sd" + + @property + def patcher(self) -> ModelPatcherProxy: + if not hasattr(self, "_patcher_proxy"): + patcher_id = self._call_rpc("get_patcher_id") + self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False) + return self._patcher_proxy + + @property + def first_stage_model(self) -> FirstStageModelProxy: + if not hasattr(self, "_first_stage_model_proxy"): + fsm_id = self._call_rpc("get_first_stage_model_id") + self._first_stage_model_proxy = FirstStageModelProxy( + fsm_id, manage_lifecycle=False + ) + return self._first_stage_model_proxy + + @property + def vae_dtype(self) -> Any: + return self._get_property("vae_dtype") + + def encode(self, pixels: Any) -> Any: + return self._call_rpc("encode", pixels) + + def encode_tiled( + self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64 + ) -> Any: + return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap) + + def decode(self, samples: Any, **kwargs: Any) -> Any: + return self._call_rpc("decode", samples, **kwargs) + + def decode_tiled( + self, + samples: Any, + tile_x: int = 64, + tile_y: int = 64, + overlap: int = 16, + **kwargs: Any, + ) -> Any: + return self._call_rpc( + "decode_tiled", samples, tile_x, tile_y, overlap, **kwargs + ) + + def get_sd(self) -> Any: + return self._call_rpc("get_sd") + + def _get_property(self, name: str) -> Any: + return self._call_rpc("get_property", name) + + @property + def latent_dim(self) -> int: + return self._get_property("latent_dim") + + @property + def latent_channels(self) -> int: + return self._get_property("latent_channels") + + @property + def downscale_ratio(self) -> Any: + return self._get_property("downscale_ratio") + + @property + def upscale_ratio(self) -> Any: + return self._get_property("upscale_ratio") + + @property + def output_channels(self) -> int: + return self._get_property("output_channels") + + @property + def check_not_vide(self) -> bool: + return self._get_property("not_video") + + @property + def device(self) -> Any: + return self._get_property("device") + + @property + def working_dtypes(self) -> Any: + return self._get_property("working_dtypes") + + @property + def disable_offload(self) -> bool: + return self._get_property("disable_offload") + + @property + def size(self) -> Any: + return self._get_property("size") + + def memory_used_encode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_encode", shape, dtype) + + def memory_used_decode(self, shape: Any, dtype: Any) -> int: + return self._call_rpc("memory_used_decode", shape, dtype) + + def process_input(self, image: Any) -> Any: + return self._call_rpc("process_input", image) + + def process_output(self, image: Any) -> Any: + return self._call_rpc("process_output", image) + + +if not IS_CHILD_PROCESS: + _VAE_REGISTRY_SINGLETON = VAERegistry() + _FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()