feat(isolation-model-proxies): model patcher + clip/vae/model sampling proxies

This commit is contained in:
John Pollock 2026-02-27 12:42:11 -06:00
parent 9ca799362d
commit a1c3124821
6 changed files with 2643 additions and 0 deletions

View File

@@ -0,0 +1,327 @@
# pylint: disable=attribute-defined-outside-init,import-outside-toplevel,logging-fstring-interpolation
# CLIP Proxy implementation
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
if TYPE_CHECKING:
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
class CondStageModelRegistry(BaseRegistry[Any]):
    """Registry that serves attribute reads for registered cond_stage_model objects."""

    _type_prefix = "cond_stage_model"

    async def get_property(self, instance_id: str, name: str) -> Any:
        """Look up the registered instance and return its attribute ``name``."""
        return getattr(self._get_instance(instance_id), name)
class CondStageModelProxy(BaseProxy[CondStageModelRegistry]):
    """Read-only stand-in for a remote cond_stage_model; attribute access goes over RPC."""

    _registry_class = CondStageModelRegistry
    __module__ = "comfy.sd"

    def __getattr__(self, name: str) -> Any:
        """Fetch attribute ``name`` from the remote instance, mapping RPC failures
        to the AttributeError callers expect from ordinary attribute access."""
        try:
            return self._call_rpc("get_property", name)
        except Exception as exc:
            message = f"'{self.__class__.__name__}' object has no attribute '{name}'"
            raise AttributeError(message) from exc

    def __repr__(self) -> str:
        return f"<CondStageModelProxy {self._instance_id}>"
class TokenizerRegistry(BaseRegistry[Any]):
    """Registry that serves attribute reads for registered tokenizer objects."""

    _type_prefix = "tokenizer"

    async def get_property(self, instance_id: str, name: str) -> Any:
        """Look up the registered instance and return its attribute ``name``."""
        return getattr(self._get_instance(instance_id), name)
class TokenizerProxy(BaseProxy[TokenizerRegistry]):
    """Read-only stand-in for a remote tokenizer; attribute access goes over RPC."""

    _registry_class = TokenizerRegistry
    __module__ = "comfy.sd"

    def __getattr__(self, name: str) -> Any:
        """Fetch attribute ``name`` from the remote instance, mapping RPC failures
        to the AttributeError callers expect from ordinary attribute access."""
        try:
            return self._call_rpc("get_property", name)
        except Exception as exc:
            message = f"'{self.__class__.__name__}' object has no attribute '{name}'"
            raise AttributeError(message) from exc

    def __repr__(self) -> str:
        return f"<TokenizerProxy {self._instance_id}>"
# Module-level logger for the CLIP proxy layer.
logger = logging.getLogger(__name__)
class CLIPRegistry(BaseRegistry[Any]):
    """RPC registry exposing real ``comfy.sd.CLIP`` instances to proxies.

    Every handler receives the ``instance_id`` produced by ``register`` and
    forwards the call to the underlying CLIP object.  Encode results pass
    through ``detach_if_grad`` so no autograd graph crosses the process
    boundary.
    """

    _type_prefix = "clip"

    # Attributes that may be assigned remotely via ``set_property``; any
    # other name is rejected to keep the writable RPC surface narrow.
    _allowed_setters = {
        "layer_idx",
        "tokenizer_options",
        "use_clip_schedule",
        "apply_hooks_to_conds",
    }

    async def get_ram_usage(self, instance_id: str) -> int:
        return self._get_instance(instance_id).get_ram_usage()

    async def get_patcher_id(self, instance_id: str) -> str:
        """Register this CLIP's ModelPatcher and return its registry id."""
        from comfy.isolation.model_patcher_proxy import ModelPatcherRegistry
        # NOTE(review): constructs ModelPatcherRegistry() per call — assumes
        # BaseRegistry subclasses resolve to a shared/singleton registry;
        # confirm in comfy.isolation.proxies.base.
        return ModelPatcherRegistry().register(self._get_instance(instance_id).patcher)

    async def get_cond_stage_model_id(self, instance_id: str) -> str:
        """Register the CLIP's text-encoder module and return its registry id."""
        return CondStageModelRegistry().register(
            self._get_instance(instance_id).cond_stage_model
        )

    async def get_tokenizer_id(self, instance_id: str) -> str:
        """Register the CLIP's tokenizer and return its registry id."""
        return TokenizerRegistry().register(self._get_instance(instance_id).tokenizer)

    async def load_model(self, instance_id: str) -> None:
        self._get_instance(instance_id).load_model()

    async def clip_layer(self, instance_id: str, layer_idx: int) -> None:
        self._get_instance(instance_id).clip_layer(layer_idx)

    async def set_tokenizer_option(
        self, instance_id: str, option_name: str, value: Any
    ) -> None:
        self._get_instance(instance_id).set_tokenizer_option(option_name, value)

    async def get_property(self, instance_id: str, name: str) -> Any:
        """Generic attribute read backing CLIPProxy's property getters."""
        return getattr(self._get_instance(instance_id), name)

    async def set_property(self, instance_id: str, name: str, value: Any) -> None:
        """Generic attribute write, restricted to ``_allowed_setters``."""
        if name not in self._allowed_setters:
            raise PermissionError(f"Setting '{name}' is not allowed via RPC")
        setattr(self._get_instance(instance_id), name, value)

    async def tokenize(
        self, instance_id: str, text: str, return_word_ids: bool = False, **kwargs: Any
    ) -> Any:
        return self._get_instance(instance_id).tokenize(
            text, return_word_ids=return_word_ids, **kwargs
        )

    async def encode(self, instance_id: str, text: str) -> Any:
        # Detach so the serialized tensor carries no autograd history.
        return detach_if_grad(self._get_instance(instance_id).encode(text))

    async def encode_from_tokens(
        self,
        instance_id: str,
        tokens: Any,
        return_pooled: bool = False,
        return_dict: bool = False,
    ) -> Any:
        return detach_if_grad(
            self._get_instance(instance_id).encode_from_tokens(
                tokens, return_pooled=return_pooled, return_dict=return_dict
            )
        )

    async def encode_from_tokens_scheduled(
        self,
        instance_id: str,
        tokens: Any,
        unprojected: bool = False,
        add_dict: Optional[dict] = None,
        show_pbar: bool = True,
    ) -> Any:
        # Normalize the over-the-wire None into the empty dict the real
        # CLIP API expects.
        add_dict = add_dict or {}
        return detach_if_grad(
            self._get_instance(instance_id).encode_from_tokens_scheduled(
                tokens, unprojected=unprojected, add_dict=add_dict, show_pbar=show_pbar
            )
        )

    async def add_patches(
        self,
        instance_id: str,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> Any:
        return self._get_instance(instance_id).add_patches(
            patches, strength_patch=strength_patch, strength_model=strength_model
        )

    async def get_key_patches(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).get_key_patches()

    async def load_sd(
        self, instance_id: str, sd: dict, full_model: bool = False
    ) -> Any:
        return self._get_instance(instance_id).load_sd(sd, full_model=full_model)

    async def get_sd(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).get_sd()

    async def clone(self, instance_id: str) -> str:
        """Clone the CLIP, register the copy, and return the new instance id."""
        return self.register(self._get_instance(instance_id).clone())
class CLIPProxy(BaseProxy[CLIPRegistry]):
    """Drop-in stand-in for ``comfy.sd.CLIP`` backed by a CLIPRegistry over RPC.

    Sub-object proxies (``patcher``, ``cond_stage_model``, ``tokenizer``)
    are created lazily on first access and cached on the instance.
    """

    _registry_class = CLIPRegistry
    # Masquerade as comfy.sd so repr/pickling and name-based checks treat
    # the proxy like the real class.
    __module__ = "comfy.sd"

    def get_ram_usage(self) -> int:
        return self._call_rpc("get_ram_usage")

    @property
    def patcher(self) -> "ModelPatcherProxy":
        """Lazily-built proxy for the remote CLIP's ModelPatcher."""
        from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
        if not hasattr(self, "_patcher_proxy"):
            patcher_id = self._call_rpc("get_patcher_id")
            # manage_lifecycle=False: the remote CLIP owns the patcher.
            self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
        return self._patcher_proxy

    @patcher.setter
    def patcher(self, value: Any) -> None:
        from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
        if isinstance(value, ModelPatcherProxy):
            self._patcher_proxy = value
        else:
            # Non-proxy assignments are ignored (with a warning) rather than
            # raising, so callers written against the real CLIP keep working.
            logger.warning(
                f"Attempted to set CLIPProxy.patcher to non-proxy object: {value}"
            )

    @property
    def cond_stage_model(self) -> CondStageModelProxy:
        """Lazily-built proxy for the remote text-encoder module."""
        if not hasattr(self, "_cond_stage_model_proxy"):
            csm_id = self._call_rpc("get_cond_stage_model_id")
            self._cond_stage_model_proxy = CondStageModelProxy(
                csm_id, manage_lifecycle=False
            )
        return self._cond_stage_model_proxy

    @property
    def tokenizer(self) -> TokenizerProxy:
        """Lazily-built proxy for the remote tokenizer."""
        if not hasattr(self, "_tokenizer_proxy"):
            tok_id = self._call_rpc("get_tokenizer_id")
            self._tokenizer_proxy = TokenizerProxy(tok_id, manage_lifecycle=False)
        return self._tokenizer_proxy

    def load_model(self) -> ModelPatcherProxy:
        """Load the remote model and return its patcher proxy (mirrors CLIP.load_model)."""
        self._call_rpc("load_model")
        return self.patcher

    @property
    def layer_idx(self) -> Optional[int]:
        return self._call_rpc("get_property", "layer_idx")

    @layer_idx.setter
    def layer_idx(self, value: Optional[int]) -> None:
        self._call_rpc("set_property", "layer_idx", value)

    @property
    def tokenizer_options(self) -> dict:
        return self._call_rpc("get_property", "tokenizer_options")

    @tokenizer_options.setter
    def tokenizer_options(self, value: dict) -> None:
        self._call_rpc("set_property", "tokenizer_options", value)

    @property
    def use_clip_schedule(self) -> bool:
        return self._call_rpc("get_property", "use_clip_schedule")

    @use_clip_schedule.setter
    def use_clip_schedule(self, value: bool) -> None:
        self._call_rpc("set_property", "use_clip_schedule", value)

    @property
    def apply_hooks_to_conds(self) -> Any:
        return self._call_rpc("get_property", "apply_hooks_to_conds")

    @apply_hooks_to_conds.setter
    def apply_hooks_to_conds(self, value: Any) -> None:
        self._call_rpc("set_property", "apply_hooks_to_conds", value)

    def clip_layer(self, layer_idx: int) -> None:
        return self._call_rpc("clip_layer", layer_idx)

    def set_tokenizer_option(self, option_name: str, value: Any) -> None:
        return self._call_rpc("set_tokenizer_option", option_name, value)

    def tokenize(self, text: str, return_word_ids: bool = False, **kwargs: Any) -> Any:
        return self._call_rpc(
            "tokenize", text, return_word_ids=return_word_ids, **kwargs
        )

    def encode(self, text: str) -> Any:
        return self._call_rpc("encode", text)

    def encode_from_tokens(
        self, tokens: Any, return_pooled: bool = False, return_dict: bool = False
    ) -> Any:
        res = self._call_rpc(
            "encode_from_tokens",
            tokens,
            return_pooled=return_pooled,
            return_dict=return_dict,
        )
        # RPC serialization turns the (cond, pooled) tuple into a list;
        # restore the tuple shape callers of the real CLIP expect.
        if return_pooled and isinstance(res, list) and not return_dict:
            return tuple(res)
        return res

    def encode_from_tokens_scheduled(
        self,
        tokens: Any,
        unprojected: bool = False,
        add_dict: Optional[dict] = None,
        show_pbar: bool = True,
    ) -> Any:
        add_dict = add_dict or {}
        return self._call_rpc(
            "encode_from_tokens_scheduled",
            tokens,
            unprojected=unprojected,
            add_dict=add_dict,
            show_pbar=show_pbar,
        )

    def add_patches(
        self, patches: Any, strength_patch: float = 1.0, strength_model: float = 1.0
    ) -> Any:
        return self._call_rpc(
            "add_patches",
            patches,
            strength_patch=strength_patch,
            strength_model=strength_model,
        )

    def get_key_patches(self) -> Any:
        return self._call_rpc("get_key_patches")

    def load_sd(self, sd: dict, full_model: bool = False) -> Any:
        return self._call_rpc("load_sd", sd, full_model=full_model)

    def get_sd(self) -> Any:
        return self._call_rpc("get_sd")

    def clone(self) -> CLIPProxy:
        """Clone the remote CLIP and wrap the new instance in a fresh proxy."""
        new_id = self._call_rpc("clone")
        # Clones created outside a child process own their remote instance;
        # clones made inside a child do not manage the remote lifetime.
        return CLIPProxy(new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS)
# Eagerly instantiate the registry singletons on the non-child side so they
# exist before any proxy issues an RPC against them.  Child processes use
# proxies instead of hosting registries.
if not IS_CHILD_PROCESS:
    _CLIP_REGISTRY_SINGLETON = CLIPRegistry()
    _COND_STAGE_MODEL_REGISTRY_SINGLETON = CondStageModelRegistry()
    _TOKENIZER_REGISTRY_SINGLETON = TokenizerRegistry()

View File

@@ -0,0 +1,820 @@
# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
# RPC proxy for ModelPatcher (parent process)
from __future__ import annotations
import logging
from typing import Any, Optional, List, Set, Dict, Callable
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
)
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
AutoPatcherEjector,
)
# Module-level logger for the ModelPatcher proxy layer.
logger = logging.getLogger(__name__)
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
    """Stand-in for ``comfy.model_patcher.ModelPatcher`` that forwards to a
    remote instance via a ModelPatcherRegistry.

    The proxy mirrors the real patcher's public surface.  Several methods
    also undo RPC-serialization artifacts (lists that were tuples, encoded
    dict keys) so callers see the same shapes the real class produces.
    """

    _registry_class = ModelPatcherRegistry
    # Masquerade as the real module so name-based checks treat the proxy
    # like a genuine ModelPatcher.
    __module__ = "comfy.model_patcher"
    # Extra free-VRAM headroom demanded before pushing apply_model inputs
    # to CUDA (see _ensure_apply_model_headroom).
    _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024

    def _get_rpc(self) -> Any:
        """Resolve and cache the RPC caller: a child-process caller when one
        exists, otherwise the local registry (same-process fallback)."""
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance
            rpc = get_child_rpc_instance()
            if rpc is not None:
                self._rpc_caller = rpc.create_caller(
                    self._registry_class, self._registry_class.get_remote_id()
                )
            else:
                self._rpc_caller = self._registry
        return self._rpc_caller

    def get_all_callbacks(self, call_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_callbacks", call_type)

    def get_all_wrappers(self, wrapper_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_wrappers", wrapper_type)

    def _load_list(self, *args, **kwargs) -> Any:
        # The real method is name-mangled on the wire as load_list_internal.
        return self._call_rpc("load_list_internal", *args, **kwargs)

    def prepare_hook_patches_current_keyframe(
        self, t: Any, hook_group: Any, model_options: Any
    ) -> None:
        self._call_rpc(
            "prepare_hook_patches_current_keyframe", t, hook_group, model_options
        )

    def add_hook_patches(
        self,
        hook: Any,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> None:
        self._call_rpc(
            "add_hook_patches", hook, patches, strength_patch, strength_model
        )

    def clear_cached_hook_weights(self) -> None:
        self._call_rpc("clear_cached_hook_weights")

    def get_combined_hook_patches(self, hooks: Any) -> Any:
        return self._call_rpc("get_combined_hook_patches", hooks)

    def get_additional_models_with_key(self, key: str) -> Any:
        return self._call_rpc("get_additional_models_with_key", key)

    @property
    def object_patches(self) -> Any:
        return self._call_rpc("get_object_patches")

    @property
    def patches(self) -> Any:
        """Remote patches dict; inner lists are converted back to the tuples
        the real ModelPatcher stores (RPC serialization flattens tuples)."""
        res = self._call_rpc("get_patches")
        if isinstance(res, dict):
            new_res = {}
            for k, v in res.items():
                new_list = []
                for item in v:
                    if isinstance(item, list):
                        new_list.append(tuple(item))
                    else:
                        new_list.append(item)
                new_res[k] = new_list
            return new_res
        return res

    @property
    def pinned(self) -> Set:
        val = self._call_rpc("get_patcher_attr", "pinned")
        # Serialization may deliver a list; callers expect a set.
        return set(val) if val is not None else set()

    @property
    def hook_patches(self) -> Dict:
        """Remote hook_patches dict with _HookRef keys reconstructed.

        Hook-ref keys cross the wire in one of two string encodings:
        "PYISOLATE_HOOKREF:<id>" or "__pyisolate_key__<json>" where the JSON
        is a list of [name, value] pairs containing an ["id", <id>] entry.
        Unrecognized keys are passed through unchanged.
        """
        val = self._call_rpc("get_patcher_attr", "hook_patches")
        if val is None:
            return {}
        try:
            from comfy.hooks import _HookRef
            import json
            new_val = {}
            for k, v in val.items():
                if isinstance(k, str):
                    if k.startswith("PYISOLATE_HOOKREF:"):
                        ref_id = k.split(":", 1)[1]
                        h = _HookRef()
                        h._pyisolate_id = ref_id
                        new_val[h] = v
                    elif k.startswith("__pyisolate_key__"):
                        try:
                            json_str = k[len("__pyisolate_key__") :]
                            data = json.loads(json_str)
                            ref_id = None
                            if isinstance(data, list):
                                for item in data:
                                    if (
                                        isinstance(item, list)
                                        and len(item) == 2
                                        and item[0] == "id"
                                    ):
                                        ref_id = item[1]
                                        break
                            if ref_id:
                                h = _HookRef()
                                h._pyisolate_id = ref_id
                                new_val[h] = v
                            else:
                                new_val[k] = v
                        except Exception:
                            # Malformed encoded key: keep it as-is rather than drop it.
                            new_val[k] = v
                    else:
                        new_val[k] = v
                else:
                    new_val[k] = v
            return new_val
        except ImportError:
            # comfy.hooks unavailable: return the raw wire representation.
            return val

    def set_hook_mode(self, hook_mode: Any) -> None:
        self._call_rpc("set_hook_mode", hook_mode)

    def register_all_hook_patches(
        self,
        hooks: Any,
        target_dict: Any,
        model_options: Any = None,
        registered: Any = None,
    ) -> None:
        self._call_rpc(
            "register_all_hook_patches", hooks, target_dict, model_options, registered
        )

    def is_clone(self, other: Any) -> bool:
        """True iff ``other`` is a proxy to a clone of the same remote model.

        Only proxy-vs-proxy comparison is supported; comparing against a
        local (non-proxy) object returns False.
        """
        if isinstance(other, ModelPatcherProxy):
            return self._call_rpc("is_clone_by_id", other._instance_id)
        return False

    def clone(self) -> ModelPatcherProxy:
        """Clone the remote patcher and wrap the copy in a new proxy."""
        new_id = self._call_rpc("clone")
        return ModelPatcherProxy(
            new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
        )

    def clone_has_same_weights(self, clone: Any) -> bool:
        if isinstance(clone, ModelPatcherProxy):
            return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
        if not IS_CHILD_PROCESS:
            # Host side can hand the raw object to the registry directly.
            return self._call_rpc("is_clone", clone)
        return False

    def get_model_object(self, name: str) -> Any:
        return self._call_rpc("get_model_object", name)

    @property
    def model_options(self) -> dict:
        """Remote model_options with encoded dict keys decoded.

        Non-string dict keys are serialized as "__pyisolate_key__<json>";
        decoded JSON lists become tuples (the usual key type for block
        addresses).  Keys that fail to decode are kept verbatim.
        """
        data = self._call_rpc("get_model_options")
        import json
        def _decode_keys(obj):
            if isinstance(obj, dict):
                new_d = {}
                for k, v in obj.items():
                    if isinstance(k, str) and k.startswith("__pyisolate_key__"):
                        try:
                            # 17 == len("__pyisolate_key__")
                            json_str = k[17:]
                            val = json.loads(json_str)
                            if isinstance(val, list):
                                val = tuple(val)
                            new_d[val] = _decode_keys(v)
                        except:
                            new_d[k] = _decode_keys(v)
                    else:
                        new_d[k] = _decode_keys(v)
                return new_d
            if isinstance(obj, list):
                return [_decode_keys(x) for x in obj]
            return obj
        return _decode_keys(data)

    @model_options.setter
    def model_options(self, value: dict) -> None:
        self._call_rpc("set_model_options", value)

    def apply_hooks(self, hooks: Any) -> Any:
        return self._call_rpc("apply_hooks", hooks)

    def prepare_state(self, timestep: Any) -> Any:
        return self._call_rpc("prepare_state", timestep)

    def restore_hook_patches(self) -> None:
        self._call_rpc("restore_hook_patches")

    def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
        self._call_rpc("unpatch_hooks", whitelist_keys_set)

    def model_patches_to(self, device: Any) -> Any:
        return self._call_rpc("model_patches_to", device)

    def partially_load(
        self, device: Any, extra_memory: Any, force_patch_weights: bool = False
    ) -> Any:
        return self._call_rpc(
            "partially_load", device, extra_memory, force_patch_weights
        )

    def partially_unload(
        self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
    ) -> int:
        return self._call_rpc(
            "partially_unload", device_to, memory_to_free, force_patch_weights
        )

    def load(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        force_patch_weights: bool = False,
        full_load: bool = False,
    ) -> None:
        self._call_rpc(
            "load", device_to, lowvram_model_memory, force_patch_weights, full_load
        )

    def patch_model(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        load_weights: bool = True,
        force_patch_weights: bool = False,
    ) -> Any:
        self._call_rpc(
            "patch_model",
            device_to,
            lowvram_model_memory,
            load_weights,
            force_patch_weights,
        )
        # The real patch_model returns self.model; the proxy returns itself
        # so chained calls stay on the proxy.
        return self

    def unpatch_model(
        self, device_to: Any = None, unpatch_weights: bool = True
    ) -> None:
        self._call_rpc("unpatch_model", device_to, unpatch_weights)

    def detach(self, unpatch_all: bool = True) -> Any:
        self._call_rpc("detach", unpatch_all)
        return self.model

    def _cpu_tensor_bytes(self, obj: Any) -> int:
        """Recursively sum the byte size of CPU tensors nested in ``obj``
        (dicts, lists, tuples); non-tensor leaves contribute zero."""
        import torch
        if isinstance(obj, torch.Tensor):
            if obj.device.type == "cpu":
                return obj.nbytes
            return 0
        if isinstance(obj, dict):
            return sum(self._cpu_tensor_bytes(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(self._cpu_tensor_bytes(v) for v in obj)
        return 0

    def _ensure_apply_model_headroom(self, required_bytes: int) -> bool:
        """Best-effort attempt to free ``required_bytes`` (+ guard padding)
        of VRAM on the load device before a CPU->CUDA transfer.

        Escalates through: GC cleanup, dynamic free_memory, non-dynamic
        free_memory, and finally load_models_gpu with a minimum-memory
        requirement.  Returns True when enough memory is free; always
        returns True for non-CUDA devices (nothing to do).
        """
        if required_bytes <= 0:
            return True
        import torch
        import comfy.model_management as model_management
        target_raw = self.load_device
        try:
            # load_device may arrive as torch.device, string, or CUDA index.
            if isinstance(target_raw, torch.device):
                target = target_raw
            elif isinstance(target_raw, str):
                target = torch.device(target_raw)
            elif isinstance(target_raw, int):
                target = torch.device(f"cuda:{target_raw}")
            else:
                target = torch.device(target_raw)
        except Exception:
            # Unrecognized device spec: skip the guard rather than fail.
            return True
        if target.type != "cuda":
            return True
        required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES
        if model_management.get_free_memory(target) >= required:
            return True
        model_management.cleanup_models_gc()
        model_management.cleanup_models()
        model_management.soft_empty_cache()
        if model_management.get_free_memory(target) < required:
            model_management.free_memory(required, target, for_dynamic=True)
            model_management.soft_empty_cache()
        if model_management.get_free_memory(target) < required:
            # Escalate to non-dynamic unloading before dispatching CUDA transfer.
            model_management.free_memory(required, target, for_dynamic=False)
            model_management.soft_empty_cache()
        if model_management.get_free_memory(target) < required:
            model_management.load_models_gpu(
                [self],
                minimum_memory_required=required,
            )
        return model_management.get_free_memory(target) >= required

    def apply_model(self, *args, **kwargs) -> Any:
        """Run the remote inner model's apply_model.

        CPU tensors in the arguments are moved to CUDA locally first (after
        ensuring VRAM headroom) so the RPC layer ships device tensors; on a
        CUDA OOM during the move, headroom recovery is retried once.
        """
        import torch
        required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
        self._ensure_apply_model_headroom(required_bytes)
        def _to_cuda(obj: Any) -> Any:
            if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
                return obj.to("cuda")
            if isinstance(obj, dict):
                return {k: _to_cuda(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_to_cuda(v) for v in obj]
            if isinstance(obj, tuple):
                return tuple(_to_cuda(v) for v in obj)
            return obj
        try:
            args_cuda = _to_cuda(args)
            kwargs_cuda = _to_cuda(kwargs)
        except torch.OutOfMemoryError:
            self._ensure_apply_model_headroom(required_bytes)
            args_cuda = _to_cuda(args)
            kwargs_cuda = _to_cuda(kwargs)
        return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)

    def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
        """Return the remote state-dict keys mapped to None (weights are not
        transferred; only key presence is exposed)."""
        keys = self._call_rpc("model_state_dict", filter_prefix)
        return dict.fromkeys(keys, None)

    def add_patches(self, *args: Any, **kwargs: Any) -> Any:
        res = self._call_rpc("add_patches", *args, **kwargs)
        # Restore tuples flattened to lists by serialization.
        if isinstance(res, list):
            return [tuple(x) if isinstance(x, list) else x for x in res]
        return res

    def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
        return self._call_rpc("get_key_patches", filter_prefix)

    def patch_weight_to_device(
        self, key: str, device_to: Any = None, inplace_update: bool = False
    ) -> None:
        self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)

    def pin_weight_to_device(self, key: str) -> None:
        self._call_rpc("pin_weight_to_device", key)

    def unpin_weight(self, key: str) -> None:
        self._call_rpc("unpin_weight", key)

    def unpin_all_weights(self) -> None:
        self._call_rpc("unpin_all_weights")

    def calculate_weight(
        self, patches: Any, weight: Any, key: str, intermediate_dtype: Any = None
    ) -> Any:
        return self._call_rpc(
            "calculate_weight", patches, weight, key, intermediate_dtype
        )

    def inject_model(self) -> None:
        self._call_rpc("inject_model")

    def eject_model(self) -> None:
        self._call_rpc("eject_model")

    def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
        """Return a context manager that ejects this (proxied) model on entry
        and re-injects it on exit, mirroring ModelPatcher.use_ejected."""
        return AutoPatcherEjector(
            self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
        )

    @property
    def is_injected(self) -> bool:
        return self._call_rpc("get_is_injected")

    @property
    def skip_injection(self) -> bool:
        return self._call_rpc("get_skip_injection")

    @skip_injection.setter
    def skip_injection(self, value: bool) -> None:
        self._call_rpc("set_skip_injection", value)

    def clean_hooks(self) -> None:
        self._call_rpc("clean_hooks")

    def pre_run(self) -> None:
        self._call_rpc("pre_run")

    def cleanup(self) -> None:
        """Clean up the remote instance, tolerating RPC failure (the channel
        may already be gone at shutdown), then release proxy-side state."""
        try:
            self._call_rpc("cleanup")
        except Exception:
            logger.debug(
                "ModelPatcherProxy cleanup RPC failed for %s",
                self._instance_id,
                exc_info=True,
            )
        finally:
            super().cleanup()

    @property
    def model(self) -> _InnerModelProxy:
        # A fresh lightweight wrapper per access; it holds no state beyond
        # a reference back to this proxy.
        return _InnerModelProxy(self)

    def __getattr__(self, name: str) -> Any:
        """Forward a fixed whitelist of ModelPatcher attributes over RPC;
        everything else raises AttributeError as usual."""
        _whitelisted_attrs = {
            "hook_patches_backup",
            "hook_backup",
            "cached_hook_patches",
            "current_hooks",
            "forced_hooks",
            "is_clip",
            "patches_uuid",
            "pinned",
            "attachments",
            "additional_models",
            "injections",
            "hook_patches",
            "model_lowvram",
            "model_loaded_weight_memory",
            "backup",
            "object_patches_backup",
            "weight_wrapper_patches",
            "weight_inplace_update",
            "force_cast_weights",
        }
        if name in _whitelisted_attrs:
            return self._call_rpc("get_patcher_attr", name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def load_lora(
        self,
        lora_path: str,
        strength_model: float,
        clip: Optional[Any] = None,
        strength_clip: float = 1.0,
    ) -> tuple:
        """Apply a LoRA remotely; returns (new_model_proxy, new_clip_proxy),
        either of which may be None when the LoRA does not touch that part."""
        clip_id = None
        if clip is not None:
            # Accept either a CLIPProxy (_instance_id) or an older wrapper
            # exposing _clip_id.
            clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
        result = self._call_rpc(
            "load_lora", lora_path, strength_model, clip_id, strength_clip
        )
        new_model = None
        if result.get("model_id"):
            new_model = ModelPatcherProxy(
                result["model_id"],
                self._registry,
                manage_lifecycle=not IS_CHILD_PROCESS,
            )
        new_clip = None
        if result.get("clip_id"):
            from comfy.isolation.clip_proxy import CLIPProxy
            new_clip = CLIPProxy(result["clip_id"])
        return (new_model, new_clip)

    @property
    def load_device(self) -> Any:
        return self._call_rpc("get_load_device")

    @property
    def offload_device(self) -> Any:
        return self._call_rpc("get_offload_device")

    @property
    def device(self) -> Any:
        # Alias: the real patcher's device is its load device.
        return self.load_device

    def current_loaded_device(self) -> Any:
        return self._call_rpc("current_loaded_device")

    @property
    def size(self) -> int:
        return self._call_rpc("get_size")

    def model_size(self) -> Any:
        return self._call_rpc("model_size")

    def loaded_size(self) -> Any:
        return self._call_rpc("loaded_size")

    def get_ram_usage(self) -> int:
        return self._call_rpc("get_ram_usage")

    def lowvram_patch_counter(self) -> int:
        return self._call_rpc("lowvram_patch_counter")

    def memory_required(self, input_shape: Any) -> Any:
        return self._call_rpc("memory_required", input_shape)

    def is_dynamic(self) -> bool:
        return bool(self._call_rpc("is_dynamic"))

    def get_free_memory(self, device: Any) -> Any:
        return self._call_rpc("get_free_memory", device)

    def partially_unload_ram(self, ram_to_unload: int) -> Any:
        return self._call_rpc("partially_unload_ram", ram_to_unload)

    def model_dtype(self) -> Any:
        """Remote model dtype; a serialized "torch.<name>" string is mapped
        back to the actual torch dtype object when torch is importable."""
        res = self._call_rpc("model_dtype")
        if isinstance(res, str) and res.startswith("torch."):
            try:
                import torch
                attr = res.split(".")[-1]
                if hasattr(torch, attr):
                    return getattr(torch, attr)
            except ImportError:
                pass
        return res

    @property
    def hook_mode(self) -> Any:
        return self._call_rpc("get_hook_mode")

    @hook_mode.setter
    def hook_mode(self, value: Any) -> None:
        self._call_rpc("set_hook_mode", value)

    def set_model_sampler_cfg_function(
        self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_cfg_function",
            sampler_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_post_cfg_function(
        self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_post_cfg_function",
            post_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_pre_cfg_function(
        self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_pre_cfg_function",
            pre_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
        self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)

    def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
        self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)

    def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
        self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)

    def set_model_patch(self, patch: Any, name: str) -> None:
        self._call_rpc("set_model_patch", patch, name)

    def set_model_patch_replace(
        self,
        patch: Any,
        name: str,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self._call_rpc(
            "set_model_patch_replace",
            patch,
            name,
            block_name,
            number,
            transformer_index,
        )

    # The following convenience setters mirror ModelPatcher's API and route
    # through set_model_patch / set_model_patch_replace with fixed names.
    def set_model_attn1_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn1", block_name, number, transformer_index
        )

    def set_model_attn2_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn2", block_name, number, transformer_index
        )

    def set_model_attn1_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "output_block_patch")

    def set_model_emb_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "emb_patch")

    def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "forward_timestep_embed_patch")

    def set_model_double_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "double_block")

    def set_model_post_input_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "post_input")

    def set_model_rope_options(
        self,
        scale_x=1.0,
        shift_x=0.0,
        scale_y=1.0,
        shift_y=0.0,
        scale_t=1.0,
        shift_t=0.0,
        **kwargs: Any,
    ) -> None:
        """Bundle the RoPE scale/shift keywords into one options dict for the
        wire; extra kwargs are merged in and may override the named ones."""
        options = {
            "scale_x": scale_x,
            "shift_x": shift_x,
            "scale_y": scale_y,
            "shift_y": shift_y,
            "scale_t": scale_t,
            "shift_t": shift_t,
        }
        options.update(kwargs)
        self._call_rpc("set_model_rope_options", options)

    def set_model_compute_dtype(self, dtype: Any) -> None:
        self._call_rpc("set_model_compute_dtype", dtype)

    def add_object_patch(self, name: str, obj: Any) -> None:
        self._call_rpc("add_object_patch", name, obj)

    def add_weight_wrapper(self, name: str, function: Any) -> None:
        self._call_rpc("add_weight_wrapper", name, function)

    def add_wrapper_with_key(
        self, wrapper_type: Any, key: Optional[str], fn: Any
    ) -> None:
        self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)

    def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
        # Keyless variant; stored under the None key remotely.
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
        self._call_rpc("remove_wrappers_with_key", wrapper_type, key)

    @property
    def wrappers(self) -> Any:
        return self._call_rpc("get_wrappers")

    def add_callback_with_key(
        self, call_type: str, key: Optional[str], callback: Any
    ) -> None:
        self._call_rpc("add_callback_with_key", call_type, key, callback)

    def add_callback(self, call_type: str, callback: Any) -> None:
        # Keyless variant; stored under the None key remotely.
        self.add_callback_with_key(call_type, None, callback)

    def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
        self._call_rpc("remove_callbacks_with_key", call_type, key)

    @property
    def callbacks(self) -> Any:
        return self._call_rpc("get_callbacks")

    def set_attachments(self, key: str, attachment: Any) -> None:
        self._call_rpc("set_attachments", key, attachment)

    def get_attachment(self, key: str) -> Any:
        return self._call_rpc("get_attachment", key)

    def remove_attachments(self, key: str) -> None:
        self._call_rpc("remove_attachments", key)

    def set_injections(self, key: str, injections: Any) -> None:
        self._call_rpc("set_injections", key, injections)

    def get_injections(self, key: str) -> Any:
        return self._call_rpc("get_injections", key)

    def remove_injections(self, key: str) -> None:
        self._call_rpc("remove_injections", key)

    def set_additional_models(self, key: str, models: Any) -> None:
        # Only instance ids cross the wire; models must be proxies.
        ids = [m._instance_id for m in models]
        self._call_rpc("set_additional_models", key, ids)

    def remove_additional_models(self, key: str) -> None:
        self._call_rpc("remove_additional_models", key)

    def get_nested_additional_models(self) -> Any:
        return self._call_rpc("get_nested_additional_models")

    def get_additional_models(self) -> List[ModelPatcherProxy]:
        """Fetch additional-model ids and wrap each in a proxy."""
        ids = self._call_rpc("get_additional_models")
        return [
            ModelPatcherProxy(
                mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
            )
            for mid in ids
        ]

    def model_patches_models(self) -> Any:
        return self._call_rpc("model_patches_models")

    @property
    def parent(self) -> Any:
        return self._call_rpc("get_parent")
class _InnerModelProxy:
    """Minimal stand-in for ``ModelPatcher.model`` (the inner BaseModel).

    Holds only a back-reference to the owning ModelPatcherProxy and resolves
    a fixed set of attribute names through it; anything outside that set
    raises AttributeError so unsupported access fails loudly.
    """

    def __init__(self, parent: ModelPatcherProxy):
        self._parent = parent

    def __getattr__(self, name: str) -> Any:
        # Private names are never proxied.
        if name.startswith("_"):
            raise AttributeError(name)
        # Plain attribute reads forwarded to the remote inner model.
        if name in (
            "model_config",
            "latent_format",
            "model_type",
            "current_weight_patches_uuid",
        ):
            return self._parent._call_rpc("get_inner_model_attr", name)
        if name == "load_device":
            return self._parent._call_rpc("get_inner_model_attr", "load_device")
        if name == "device":
            return self._parent._call_rpc("get_inner_model_attr", "device")
        if name == "current_patcher":
            # A non-owning proxy to the same remote patcher instance.
            return ModelPatcherProxy(
                self._parent._instance_id,
                self._parent._registry,
                manage_lifecycle=False,
            )
        if name == "model_sampling":
            return self._parent._call_rpc("get_model_object", "model_sampling")
        # Callable attributes: return lambdas that ship (args, kwargs) as a
        # pair over the wire.
        if name == "extra_conds_shapes":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_extra_conds_shapes", a, k
            )
        if name == "extra_conds":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_extra_conds", a, k
            )
        if name == "memory_required":
            return lambda *a, **k: self._parent._call_rpc(
                "inner_model_memory_required", a, k
            )
        if name == "apply_model":
            # Delegate to parent's method to get the CPU->CUDA optimization
            return self._parent.apply_model
        if name == "process_latent_in":
            return lambda *a, **k: self._parent._call_rpc("process_latent_in", a, k)
        if name == "process_latent_out":
            return lambda *a, **k: self._parent._call_rpc("process_latent_out", a, k)
        if name == "scale_latent_inpaint":
            return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
        if name == "diffusion_model":
            return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
        raise AttributeError(f"'{name}' not supported on isolated InnerModel")

View File

@@ -0,0 +1,875 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access,unused-import
# RPC server for ModelPatcher isolation (child process)
from __future__ import annotations
import gc
import logging
from typing import Any, Optional, List
try:
    # Prefer the real implementation from comfy.model_patcher.
    from comfy.model_patcher import AutoPatcherEjector
except ImportError:
    class AutoPatcherEjector:
        """Fallback context manager mirroring comfy's AutoPatcherEjector.

        On entry, ejects the model if it is currently injected (optionally
        also forcing skip_injection).  On exit, restores the previous
        skip_injection state and re-injects the model when appropriate.
        """

        def __init__(self, model, skip_and_inject_on_exit_only=False):
            self.model = model
            self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only
            self.prev_skip_injection = False
            self.was_injected = False

        def __enter__(self):
            self.was_injected = False
            self.prev_skip_injection = self.model.skip_injection
            if self.skip_and_inject_on_exit_only:
                # Suppress any injection attempts inside the with-block;
                # injection is deferred to __exit__.
                self.model.skip_injection = True
            if self.model.is_injected:
                self.model.eject_model()
                self.was_injected = True

        def __exit__(self, *args):
            if self.skip_and_inject_on_exit_only:
                self.model.skip_injection = self.prev_skip_injection
                self.model.inject_model()
            # Re-inject only if we ejected on entry and injection is allowed.
            if self.was_injected and not self.model.skip_injection:
                self.model.inject_model()
            self.model.skip_injection = self.prev_skip_injection
from comfy.isolation.proxies.base import (
BaseRegistry,
detach_if_grad,
)
logger = logging.getLogger(__name__)
class ModelPatcherRegistry(BaseRegistry[Any]):
    """Child-process RPC registry exposing ModelPatcher operations.

    Callers hold an ``instance_id`` obtained from ``register``; each ``async``
    method resolves the real ModelPatcher and forwards the call, sanitizing or
    summarizing results that cannot cross the RPC boundary (tensors, callables,
    arbitrary objects).
    """
    _type_prefix = "model"
    def __init__(self) -> None:
        super().__init__()
        # IDs whose cleanup() already ran; the instances themselves are only
        # dropped later by sweep_pending_cleanup().
        self._pending_cleanup_ids: set[str] = set()
    async def clone(self, instance_id: str) -> str:
        """Clone the patcher and return the clone's registry id."""
        instance = self._get_instance(instance_id)
        new_model = instance.clone()
        return self.register(new_model)
    async def is_clone(self, instance_id: str, other: Any) -> bool:
        instance = self._get_instance(instance_id)
        # Only objects that look like a ModelPatcher (have .model) can match.
        if hasattr(other, "model"):
            return instance.is_clone(other)
        return False
    async def get_model_object(self, instance_id: str, name: str) -> Any:
        """Fetch a sub-object of the model; special-cases "model" and "model_sampling"."""
        instance = self._get_instance(instance_id)
        if name == "model":
            # The raw model cannot be serialized; return a descriptive marker.
            return f"<ModelObject: {type(instance.model).__name__}>"
        result = instance.get_model_object(name)
        if name == "model_sampling":
            from comfy.isolation.model_sampling_proxy import (
                ModelSamplingRegistry,
                ModelSamplingProxy,
            )
            # NOTE(review): a fresh ModelSamplingRegistry() is constructed per
            # call; presumably BaseRegistry state is process-wide — confirm.
            registry = ModelSamplingRegistry()
            sampling_id = registry.register(result)
            return ModelSamplingProxy(sampling_id, registry)
        return detach_if_grad(result)
    async def get_model_options(self, instance_id: str) -> dict:
        instance = self._get_instance(instance_id)
        import copy
        # Deep-copy so the caller cannot mutate the live options dict.
        opts = copy.deepcopy(instance.model_options)
        return self._sanitize_rpc_result(opts)
    async def set_model_options(self, instance_id: str, options: dict) -> None:
        self._get_instance(instance_id).model_options = options
    async def get_patcher_attr(self, instance_id: str, name: str) -> Any:
        # Missing attributes come back as None rather than raising.
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), name, None)
        )
    async def model_state_dict(self, instance_id: str, filter_prefix=None) -> Any:
        """Return the state-dict *keys* only, with values replaced by None.

        NOTE(review): ``filter_prefix`` is accepted but currently ignored.
        """
        instance = self._get_instance(instance_id)
        sd_keys = instance.model.state_dict().keys()
        return dict.fromkeys(sd_keys, None)
    def _sanitize_rpc_result(self, obj, seen=None):
        """Recursively convert *obj* into plainly-serializable data.

        Scalars pass through (very long strings truncated), containers are
        rebuilt, tuple dict-keys are JSON-encoded with a marker prefix, plain
        data objects are flattened to their __dict__, and anything else becomes
        None. ``seen`` guards against cycles; note it is shared across
        siblings, so a second reference to the same object also collapses to
        None.
        """
        if seen is None:
            seen = set()
        if obj is None:
            return None
        if isinstance(obj, (bool, int, float, str)):
            if isinstance(obj, str) and len(obj) > 500000:
                return f"<Truncated String len={len(obj)}>"
            return obj
        obj_id = id(obj)
        if obj_id in seen:
            return None
        seen.add(obj_id)
        if isinstance(obj, (list, tuple)):
            return [self._sanitize_rpc_result(x, seen) for x in obj]
        if isinstance(obj, set):
            # Sets become lists (JSON/RPC friendly).
            return [self._sanitize_rpc_result(x, seen) for x in obj]
        if isinstance(obj, dict):
            new_dict = {}
            for k, v in obj.items():
                if isinstance(k, tuple):
                    import json
                    try:
                        # Tuple keys cannot cross RPC; encode as tagged JSON.
                        key_str = "__pyisolate_key__" + json.dumps(list(k))
                        new_dict[key_str] = self._sanitize_rpc_result(v, seen)
                    except Exception:
                        new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
                else:
                    new_dict[str(k)] = self._sanitize_rpc_result(v, seen)
            return new_dict
        if (
            hasattr(obj, "__dict__")
            and not hasattr(obj, "__get__")
            and not hasattr(obj, "__call__")
        ):
            # Plain data objects (not descriptors/callables) flatten to __dict__.
            return self._sanitize_rpc_result(obj.__dict__, seen)
        if hasattr(obj, "items") and hasattr(obj, "get"):
            # Dict-like objects without a usable __dict__.
            return {str(k): self._sanitize_rpc_result(v, seen) for k, v in obj.items()}
        return None
    # --- Device / size queries (thin pass-throughs) ---
    async def get_load_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).load_device
    async def get_offload_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).offload_device
    async def current_loaded_device(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).current_loaded_device()
    async def get_size(self, instance_id: str) -> int:
        return self._get_instance(instance_id).size
    async def model_size(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).model_size()
    async def loaded_size(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).loaded_size()
    async def get_ram_usage(self, instance_id: str) -> int:
        return self._get_instance(instance_id).get_ram_usage()
    async def lowvram_patch_counter(self, instance_id: str) -> int:
        return self._get_instance(instance_id).lowvram_patch_counter()
    async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
        return self._get_instance(instance_id).memory_required(input_shape)
    async def is_dynamic(self, instance_id: str) -> bool:
        # Patchers without is_dynamic() default to False.
        instance = self._get_instance(instance_id)
        if hasattr(instance, "is_dynamic"):
            return bool(instance.is_dynamic())
        return False
    async def get_free_memory(self, instance_id: str, device: Any) -> Any:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "get_free_memory"):
            return instance.get_free_memory(device)
        # Fall back to the global model-management query.
        import comfy.model_management
        return comfy.model_management.get_free_memory(device)
    async def partially_unload_ram(self, instance_id: str, ram_to_unload: int) -> Any:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "partially_unload_ram"):
            return instance.partially_unload_ram(ram_to_unload)
        return None
    async def model_dtype(self, instance_id: str) -> Any:
        return self._get_instance(instance_id).model_dtype()
    async def model_patches_to(self, instance_id: str, device: Any) -> Any:
        return self._get_instance(instance_id).model_patches_to(device)
    # --- Loading / patching lifecycle ---
    async def partially_load(
        self,
        instance_id: str,
        device: Any,
        extra_memory: Any,
        force_patch_weights: bool = False,
    ) -> Any:
        return self._get_instance(instance_id).partially_load(
            device, extra_memory, force_patch_weights=force_patch_weights
        )
    async def partially_unload(
        self,
        instance_id: str,
        device_to: Any,
        memory_to_free: int = 0,
        force_patch_weights: bool = False,
    ) -> int:
        return self._get_instance(instance_id).partially_unload(
            device_to, memory_to_free, force_patch_weights
        )
    async def load(
        self,
        instance_id: str,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        force_patch_weights: bool = False,
        full_load: bool = False,
    ) -> None:
        self._get_instance(instance_id).load(
            device_to, lowvram_model_memory, force_patch_weights, full_load
        )
    async def patch_model(
        self,
        instance_id: str,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        load_weights: bool = True,
        force_patch_weights: bool = False,
    ) -> None:
        try:
            self._get_instance(instance_id).patch_model(
                device_to, lowvram_model_memory, load_weights, force_patch_weights
            )
        except AttributeError as e:
            # Attributes lost during serialization are logged and skipped,
            # not treated as fatal.
            logger.error(
                f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
            )
            return
    async def unpatch_model(
        self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
    ) -> None:
        self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights)
    async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
        self._get_instance(instance_id).detach(unpatch_all)
    async def prepare_state(self, instance_id: str, timestep: Any) -> Any:
        instance = self._get_instance(instance_id)
        # Prefer the model's currently-active patcher when one is set.
        cp = getattr(instance.model, "current_patcher", instance)
        if cp is None:
            cp = instance
        return cp.prepare_state(timestep)
    async def pre_run(self, instance_id: str) -> None:
        self._get_instance(instance_id).pre_run()
    async def cleanup(self, instance_id: str) -> None:
        """Run the patcher's cleanup and queue the id for later removal."""
        try:
            instance = self._get_instance(instance_id)
        except Exception:
            logger.debug(
                "ModelPatcher cleanup requested for missing instance %s",
                instance_id,
                exc_info=True,
            )
            return
        try:
            instance.cleanup()
        finally:
            # Queue even if cleanup() raised, so the entry is still reclaimed.
            with self._lock:
                self._pending_cleanup_ids.add(instance_id)
        gc.collect()
    def sweep_pending_cleanup(self) -> int:
        """Drop every instance queued by cleanup(); returns the removed count."""
        removed = 0
        with self._lock:
            pending_ids = list(self._pending_cleanup_ids)
            self._pending_cleanup_ids.clear()
            for instance_id in pending_ids:
                instance = self._registry.pop(instance_id, None)
                if instance is None:
                    continue
                self._id_map.pop(id(instance), None)
                removed += 1
        gc.collect()
        return removed
    def purge_all(self) -> int:
        """Unconditionally clear the whole registry; returns the removed count."""
        with self._lock:
            removed = len(self._registry)
            self._registry.clear()
            self._id_map.clear()
            self._pending_cleanup_ids.clear()
        gc.collect()
        return removed
    # --- Hooks ---
    async def apply_hooks(self, instance_id: str, hooks: Any) -> Any:
        instance = self._get_instance(instance_id)
        # Prefer the model's currently-active patcher when one is set.
        cp = getattr(instance.model, "current_patcher", instance)
        if cp is None:
            cp = instance
        return cp.apply_hooks(hooks=hooks)
    async def clean_hooks(self, instance_id: str) -> None:
        self._get_instance(instance_id).clean_hooks()
    async def restore_hook_patches(self, instance_id: str) -> None:
        self._get_instance(instance_id).restore_hook_patches()
    async def unpatch_hooks(
        self, instance_id: str, whitelist_keys_set: Optional[set] = None
    ) -> None:
        self._get_instance(instance_id).unpatch_hooks(whitelist_keys_set)
    async def register_all_hook_patches(
        self,
        instance_id: str,
        hooks: Any,
        target_dict: Any,
        model_options: Any,
        registered: Any,
    ) -> None:
        """Rebuild a HookGroup from a deserialized namespace, then register it."""
        from types import SimpleNamespace
        import comfy.hooks
        instance = self._get_instance(instance_id)
        if isinstance(hooks, SimpleNamespace) or hasattr(hooks, "__dict__"):
            hook_data = hooks.__dict__ if hasattr(hooks, "__dict__") else hooks
            new_hooks = comfy.hooks.HookGroup()
            # NOTE(review): hook_data is typically a plain dict here, and a dict
            # does not satisfy hasattr(hook_data, "hooks") — the dict branch
            # below looks unreachable; confirm the intended behavior.
            if hasattr(hook_data, "hooks"):
                new_hooks.hooks = (
                    hook_data["hooks"]
                    if isinstance(hook_data, dict)
                    else hook_data.hooks
                )
            hooks = new_hooks
        instance.register_all_hook_patches(
            hooks, target_dict, model_options, registered
        )
    async def get_hook_mode(self, instance_id: str) -> Any:
        return getattr(self._get_instance(instance_id), "hook_mode", None)
    async def set_hook_mode(self, instance_id: str, value: Any) -> None:
        setattr(self._get_instance(instance_id), "hook_mode", value)
    async def inject_model(self, instance_id: str) -> None:
        instance = self._get_instance(instance_id)
        try:
            instance.inject_model()
        except AttributeError as e:
            # Only swallow the specific "lost injector" failure; re-raise others.
            if "inject" in str(e):
                logger.error(
                    "Isolation Error: Injector object lost method code during serialization. Cannot inject. Skipping."
                )
                return
            raise e
    async def eject_model(self, instance_id: str) -> None:
        self._get_instance(instance_id).eject_model()
    async def get_is_injected(self, instance_id: str) -> bool:
        return self._get_instance(instance_id).is_injected
    async def set_skip_injection(self, instance_id: str, value: bool) -> None:
        self._get_instance(instance_id).skip_injection = value
    async def get_skip_injection(self, instance_id: str) -> bool:
        return self._get_instance(instance_id).skip_injection
    # --- Sampler / model-function patch setters (thin pass-throughs) ---
    async def set_model_sampler_cfg_function(
        self,
        instance_id: str,
        sampler_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        # Functions may arrive non-callable if serialization stripped them.
        if not callable(sampler_cfg_function):
            logger.error(
                f"set_model_sampler_cfg_function: Expected callable, got {type(sampler_cfg_function)}. Skipping."
            )
            return
        self._get_instance(instance_id).set_model_sampler_cfg_function(
            sampler_cfg_function, disable_cfg1_optimization
        )
    async def set_model_sampler_post_cfg_function(
        self,
        instance_id: str,
        post_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_post_cfg_function(
            post_cfg_function, disable_cfg1_optimization
        )
    async def set_model_sampler_pre_cfg_function(
        self,
        instance_id: str,
        pre_cfg_function: Any,
        disable_cfg1_optimization: bool = False,
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_pre_cfg_function(
            pre_cfg_function, disable_cfg1_optimization
        )
    async def set_model_sampler_calc_cond_batch_function(
        self, instance_id: str, fn: Any
    ) -> None:
        self._get_instance(instance_id).set_model_sampler_calc_cond_batch_function(fn)
    async def set_model_unet_function_wrapper(
        self, instance_id: str, unet_wrapper_function: Any
    ) -> None:
        self._get_instance(instance_id).set_model_unet_function_wrapper(
            unet_wrapper_function
        )
    async def set_model_denoise_mask_function(
        self, instance_id: str, denoise_mask_function: Any
    ) -> None:
        self._get_instance(instance_id).set_model_denoise_mask_function(
            denoise_mask_function
        )
    async def set_model_patch(self, instance_id: str, patch: Any, name: str) -> None:
        self._get_instance(instance_id).set_model_patch(patch, name)
    async def set_model_patch_replace(
        self,
        instance_id: str,
        patch: Any,
        name: str,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self._get_instance(instance_id).set_model_patch_replace(
            patch, name, block_name, number, transformer_index
        )
    async def set_model_input_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_input_block_patch(patch)
    async def set_model_input_block_patch_after_skip(
        self, instance_id: str, patch: Any
    ) -> None:
        self._get_instance(instance_id).set_model_input_block_patch_after_skip(patch)
    async def set_model_output_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_output_block_patch(patch)
    async def set_model_emb_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_emb_patch(patch)
    async def set_model_forward_timestep_embed_patch(
        self, instance_id: str, patch: Any
    ) -> None:
        self._get_instance(instance_id).set_model_forward_timestep_embed_patch(patch)
    async def set_model_double_block_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_double_block_patch(patch)
    async def set_model_post_input_patch(self, instance_id: str, patch: Any) -> None:
        self._get_instance(instance_id).set_model_post_input_patch(patch)
    async def set_model_rope_options(self, instance_id: str, options: dict) -> None:
        # Options dict is splatted into keyword arguments.
        self._get_instance(instance_id).set_model_rope_options(**options)
    async def set_model_compute_dtype(self, instance_id: str, dtype: Any) -> None:
        self._get_instance(instance_id).set_model_compute_dtype(dtype)
    # --- Cross-instance comparisons (both sides resolved by id) ---
    async def clone_has_same_weights_by_id(
        self, instance_id: str, other_id: str
    ) -> bool:
        instance = self._get_instance(instance_id)
        other = self._get_instance(other_id)
        if not other:
            return False
        return instance.clone_has_same_weights(other)
    async def load_list_internal(self, instance_id: str, *args, **kwargs) -> Any:
        return self._get_instance(instance_id)._load_list(*args, **kwargs)
    async def is_clone_by_id(self, instance_id: str, other_id: str) -> bool:
        instance = self._get_instance(instance_id)
        other = self._get_instance(other_id)
        if hasattr(instance, "is_clone"):
            return instance.is_clone(other)
        return False
    # --- Object patches, wrappers, callbacks, attachments, injections ---
    async def add_object_patch(self, instance_id: str, name: str, obj: Any) -> None:
        self._get_instance(instance_id).add_object_patch(name, obj)
    async def add_weight_wrapper(
        self, instance_id: str, name: str, function: Any
    ) -> None:
        self._get_instance(instance_id).add_weight_wrapper(name, function)
    async def add_wrapper_with_key(
        self, instance_id: str, wrapper_type: Any, key: str, fn: Any
    ) -> None:
        self._get_instance(instance_id).add_wrapper_with_key(wrapper_type, key, fn)
    async def remove_wrappers_with_key(
        self, instance_id: str, wrapper_type: str, key: str
    ) -> None:
        self._get_instance(instance_id).remove_wrappers_with_key(wrapper_type, key)
    async def get_wrappers(
        self, instance_id: str, wrapper_type: Optional[str] = None, key: Optional[str] = None
    ) -> Any:
        # With no filters, return the whole wrappers mapping.
        if wrapper_type is None and key is None:
            return self._sanitize_rpc_result(
                getattr(self._get_instance(instance_id), "wrappers", {})
            )
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_wrappers(wrapper_type, key)
        )
    async def get_all_wrappers(self, instance_id: str, wrapper_type: Optional[str] = None) -> Any:
        # Missing get_all_wrappers degrades to an empty list.
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), "get_all_wrappers", lambda x: [])(
                wrapper_type
            )
        )
    async def add_callback_with_key(
        self, instance_id: str, call_type: str, key: str, callback: Any
    ) -> None:
        self._get_instance(instance_id).add_callback_with_key(call_type, key, callback)
    async def remove_callbacks_with_key(
        self, instance_id: str, call_type: str, key: str
    ) -> None:
        self._get_instance(instance_id).remove_callbacks_with_key(call_type, key)
    async def get_callbacks(
        self, instance_id: str, call_type: Optional[str] = None, key: Optional[str] = None
    ) -> Any:
        # With no filters, return the whole callbacks mapping.
        if call_type is None and key is None:
            return self._sanitize_rpc_result(
                getattr(self._get_instance(instance_id), "callbacks", {})
            )
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_callbacks(call_type, key)
        )
    async def get_all_callbacks(self, instance_id: str, call_type: Optional[str] = None) -> Any:
        return self._sanitize_rpc_result(
            getattr(self._get_instance(instance_id), "get_all_callbacks", lambda x: [])(
                call_type
            )
        )
    async def set_attachments(
        self, instance_id: str, key: str, attachment: Any
    ) -> None:
        self._get_instance(instance_id).set_attachments(key, attachment)
    async def get_attachment(self, instance_id: str, key: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_attachment(key)
        )
    async def remove_attachments(self, instance_id: str, key: str) -> None:
        self._get_instance(instance_id).remove_attachments(key)
    async def set_injections(self, instance_id: str, key: str, injections: Any) -> None:
        self._get_instance(instance_id).set_injections(key, injections)
    async def get_injections(self, instance_id: str, key: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_injections(key)
        )
    async def remove_injections(self, instance_id: str, key: str) -> None:
        self._get_instance(instance_id).remove_injections(key)
    async def set_additional_models(
        self, instance_id: str, key: str, models: Any
    ) -> None:
        self._get_instance(instance_id).set_additional_models(key, models)
    async def remove_additional_models(self, instance_id: str, key: str) -> None:
        self._get_instance(instance_id).remove_additional_models(key)
    async def get_nested_additional_models(self, instance_id: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_nested_additional_models()
        )
    async def get_additional_models(self, instance_id: str) -> List[str]:
        # Additional models are registered here and returned as ids, not values.
        models = self._get_instance(instance_id).get_additional_models()
        return [self.register(m) for m in models]
    async def get_additional_models_with_key(self, instance_id: str, key: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).get_additional_models_with_key(key)
        )
    async def model_patches_models(self, instance_id: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).model_patches_models()
        )
    # --- Weight patches ---
    async def get_patches(self, instance_id: str) -> Any:
        return self._sanitize_rpc_result(self._get_instance(instance_id).patches.copy())
    async def get_object_patches(self, instance_id: str) -> Any:
        return self._sanitize_rpc_result(
            self._get_instance(instance_id).object_patches.copy()
        )
    async def add_patches(
        self,
        instance_id: str,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> Any:
        return self._get_instance(instance_id).add_patches(
            patches, strength_patch, strength_model
        )
    async def get_key_patches(
        self, instance_id: str, filter_prefix: Optional[str] = None
    ) -> Any:
        """Return key patches with tensors replaced by descriptive strings."""
        res = self._get_instance(instance_id).get_key_patches()
        if filter_prefix:
            res = {k: v for k, v in res.items() if k.startswith(filter_prefix)}
        safe_res = {}
        for k, v in res.items():
            safe_res[k] = [
                f"<Tensor shape={t.shape} dtype={t.dtype}>"
                if hasattr(t, "shape")
                else str(t)
                for t in v
            ]
        return safe_res
    async def add_hook_patches(
        self,
        instance_id: str,
        hook: Any,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> None:
        # Deserialized hook_refs can arrive as dicts; make them hashable again.
        if hasattr(hook, "hook_ref") and isinstance(hook.hook_ref, dict):
            try:
                hook.hook_ref = tuple(sorted(hook.hook_ref.items()))
            except Exception:
                hook.hook_ref = None
        self._get_instance(instance_id).add_hook_patches(
            hook, patches, strength_patch, strength_model
        )
    async def get_combined_hook_patches(self, instance_id: str, hooks: Any) -> Any:
        # Repair per-hook hook_ref hashability before delegating.
        if hooks is not None and hasattr(hooks, "hooks"):
            for hook in getattr(hooks, "hooks", []):
                hook_ref = getattr(hook, "hook_ref", None)
                if isinstance(hook_ref, dict):
                    try:
                        hook.hook_ref = tuple(sorted(hook_ref.items()))
                    except Exception:
                        hook.hook_ref = None
        res = self._get_instance(instance_id).get_combined_hook_patches(hooks)
        return self._sanitize_rpc_result(res)
    async def clear_cached_hook_weights(self, instance_id: str) -> None:
        self._get_instance(instance_id).clear_cached_hook_weights()
    async def prepare_hook_patches_current_keyframe(
        self, instance_id: str, t: Any, hook_group: Any, model_options: Any
    ) -> None:
        self._get_instance(instance_id).prepare_hook_patches_current_keyframe(
            t, hook_group, model_options
        )
    async def get_parent(self, instance_id: str) -> Any:
        return getattr(self._get_instance(instance_id), "parent", None)
    async def patch_weight_to_device(
        self,
        instance_id: str,
        key: str,
        device_to: Any = None,
        inplace_update: bool = False,
    ) -> None:
        self._get_instance(instance_id).patch_weight_to_device(
            key, device_to, inplace_update
        )
    async def pin_weight_to_device(self, instance_id: str, key: str) -> None:
        instance = self._get_instance(instance_id)
        # Serialization may turn the pinned set into a list; restore it.
        if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
            instance.pinned = set(instance.pinned)
        instance.pin_weight_to_device(key)
    async def unpin_weight(self, instance_id: str, key: str) -> None:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
            instance.pinned = set(instance.pinned)
        instance.unpin_weight(key)
    async def unpin_all_weights(self, instance_id: str) -> None:
        instance = self._get_instance(instance_id)
        if hasattr(instance, "pinned") and isinstance(instance.pinned, list):
            instance.pinned = set(instance.pinned)
        instance.unpin_all_weights()
    async def calculate_weight(
        self,
        instance_id: str,
        patches: Any,
        weight: Any,
        key: str,
        intermediate_dtype: Any = float,
    ) -> Any:
        # NOTE(review): the default intermediate_dtype is the Python builtin
        # float — confirm callers always pass an explicit torch dtype.
        return detach_if_grad(
            self._get_instance(instance_id).calculate_weight(
                patches, weight, key, intermediate_dtype
            )
        )
    # --- Inner-model pass-throughs ---
    async def get_inner_model_attr(self, instance_id: str, name: str) -> Any:
        try:
            return self._sanitize_rpc_result(
                getattr(self._get_instance(instance_id).model, name)
            )
        except AttributeError:
            return None
    async def inner_model_memory_required(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        return self._get_instance(instance_id).model.memory_required(*args, **kwargs)
    async def inner_model_extra_conds_shapes(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs)
    async def inner_model_extra_conds(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        return self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
    async def inner_model_state_dict(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        # Tensors are summarized by size metadata instead of being transferred.
        sd = self._get_instance(instance_id).model.state_dict(*args, **kwargs)
        return {
            k: {"numel": v.numel(), "element_size": v.element_size()}
            for k, v in sd.items()
        }
    async def inner_model_apply_model(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        """Run apply_model with inputs moved to a target device.

        The target is the patcher's load_device when present, otherwise the
        device of the first tensor-like argument. The result is moved the same
        way and detached before crossing the RPC boundary.
        """
        instance = self._get_instance(instance_id)
        target = getattr(instance, "load_device", None)
        if target is None and args and hasattr(args[0], "device"):
            target = args[0].device
        elif target is None:
            for v in kwargs.values():
                if hasattr(v, "device"):
                    target = v.device
                    break
        def _move(obj):
            # Recursively move anything with a .to() method to the target.
            if target is None:
                return obj
            if isinstance(obj, (tuple, list)):
                return type(obj)(_move(o) for o in obj)
            if hasattr(obj, "to"):
                return obj.to(target)
            return obj
        moved_args = tuple(_move(a) for a in args)
        moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
        result = instance.model.apply_model(*moved_args, **moved_kwargs)
        return detach_if_grad(_move(result))
    async def process_latent_in(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        return detach_if_grad(
            self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
        )
    async def process_latent_out(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        instance = self._get_instance(instance_id)
        result = instance.model.process_latent_out(*args, **kwargs)
        try:
            # Best-effort: return the result on the same device as the input.
            target = None
            if args and hasattr(args[0], "device"):
                target = args[0].device
            elif kwargs:
                for v in kwargs.values():
                    if hasattr(v, "device"):
                        target = v.device
                        break
            if target is not None and hasattr(result, "to"):
                return detach_if_grad(result.to(target))
        except Exception:
            logger.debug(
                "process_latent_out: failed to move result to target device",
                exc_info=True,
            )
        return detach_if_grad(result)
    async def scale_latent_inpaint(
        self, instance_id: str, args: tuple, kwargs: dict
    ) -> Any:
        instance = self._get_instance(instance_id)
        result = instance.model.scale_latent_inpaint(*args, **kwargs)
        try:
            # Best-effort: return the result on the same device as the input.
            target = None
            if args and hasattr(args[0], "device"):
                target = args[0].device
            elif kwargs:
                for v in kwargs.values():
                    if hasattr(v, "device"):
                        target = v.device
                        break
            if target is not None and hasattr(result, "to"):
                return detach_if_grad(result.to(target))
        except Exception:
            logger.debug(
                "scale_latent_inpaint: failed to move result to target device",
                exc_info=True,
            )
        return detach_if_grad(result)
    async def load_lora(
        self,
        instance_id: str,
        lora_path: str,
        strength_model: float,
        clip_id: Optional[str] = None,
        strength_clip: float = 1.0,
    ) -> dict:
        """Load a LoRA into this model (and optionally a registered CLIP).

        Returns a dict {"model_id": ..., "clip_id": ...} of newly registered
        ids; clip_id is None unless a clip_id was supplied and a new clip was
        produced. Raises ValueError if the LoRA file cannot be found.
        """
        import comfy.utils
        import comfy.sd
        import folder_paths
        from comfy.isolation.clip_proxy import CLIPRegistry
        model = self._get_instance(instance_id)
        clip = None
        if clip_id:
            clip = CLIPRegistry()._get_instance(clip_id)
        lora_full_path = folder_paths.get_full_path("loras", lora_path)
        if lora_full_path is None:
            raise ValueError(f"LoRA file not found: {lora_path}")
        lora = comfy.utils.load_torch_file(lora_full_path)
        new_model, new_clip = comfy.sd.load_lora_for_models(
            model, clip, lora, strength_model, strength_clip
        )
        new_model_id = self.register(new_model) if new_model else None
        new_clip_id = (
            CLIPRegistry().register(new_clip) if (new_clip and clip_id) else None
        )
        return {"model_id": new_model_id, "clip_id": new_clip_id}

View File

@ -0,0 +1,154 @@
# pylint: disable=import-outside-toplevel,logging-fstring-interpolation,protected-access
# Isolation utilities and serializers for ModelPatcherProxy
from __future__ import annotations
import logging
import os
from typing import Any
logger = logging.getLogger(__name__)
def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
    """Return *model_patcher* wrapped in a ModelPatcherProxy when appropriate.

    Wrapping only happens in the host process while isolation is active;
    child processes, inactive isolation, and already-proxied patchers all get
    the original object back unchanged.
    """
    from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
    from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
    active_in_host = (
        os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
        and os.environ.get("PYISOLATE_CHILD") != "1"
    )
    if not active_in_host:
        return model_patcher
    if isinstance(model_patcher, ModelPatcherProxy):
        return model_patcher
    registry = ModelPatcherRegistry()
    new_id = registry.register(model_patcher)
    logger.debug(f"Isolated ModelPatcher: {new_id}")
    return ModelPatcherProxy(new_id, registry, manage_lifecycle=True)
def register_hooks_serializers(registry=None):
    """Register pyisolate (de)serializers for comfy.hooks types.

    Hook enums, HookGroup, keyframes, hooks, plain sets, LoRA adapters and
    _HookRef all need custom handling to cross the process boundary. If
    *registry* is None, the process-wide SerializerRegistry singleton is used.
    """
    from pyisolate._internal.serialization_registry import SerializerRegistry
    import comfy.hooks
    if registry is None:
        registry = SerializerRegistry.get_instance()
    def serialize_enum(obj):
        # Encoded as "ClassName.MEMBER" and resolved from comfy.hooks on load.
        return {"__enum__": f"{type(obj).__name__}.{obj.name}"}
    def deserialize_enum(data):
        cls_name, val_name = data["__enum__"].split(".")
        cls = getattr(comfy.hooks, cls_name)
        return cls[val_name]
    registry.register("EnumHookType", serialize_enum, deserialize_enum)
    registry.register("EnumHookScope", serialize_enum, deserialize_enum)
    registry.register("EnumHookMode", serialize_enum, deserialize_enum)
    registry.register("EnumWeightTarget", serialize_enum, deserialize_enum)
    def serialize_hook_group(obj):
        return {"__type__": "HookGroup", "hooks": obj.hooks}
    def deserialize_hook_group(data):
        hg = comfy.hooks.HookGroup()
        for h in data["hooks"]:
            hg.add(h)
        return hg
    registry.register("HookGroup", serialize_hook_group, deserialize_hook_group)
    def serialize_dict_state(obj):
        # Generic __dict__ snapshot; the custom_should_register callback is a
        # callable and cannot be serialized, so it is dropped.
        d = obj.__dict__.copy()
        d["__type__"] = type(obj).__name__
        if "custom_should_register" in d:
            del d["custom_should_register"]
        return d
    # NOTE(review): deserialize_dict_state_generic is defined but never used.
    def deserialize_dict_state_generic(cls):
        def _deserialize(data):
            h = cls()
            h.__dict__.update(data)
            return h
        return _deserialize
    def deserialize_hook_keyframe(data):
        h = comfy.hooks.HookKeyframe(strength=data.get("strength", 1.0))
        h.__dict__.update(data)
        return h
    registry.register("HookKeyframe", serialize_dict_state, deserialize_hook_keyframe)
    def deserialize_hook_keyframe_group(data):
        h = comfy.hooks.HookKeyframeGroup()
        h.__dict__.update(data)
        return h
    registry.register(
        "HookKeyframeGroup", serialize_dict_state, deserialize_hook_keyframe_group
    )
    def deserialize_hook(data):
        h = comfy.hooks.Hook()
        h.__dict__.update(data)
        return h
    registry.register("Hook", serialize_dict_state, deserialize_hook)
    def deserialize_weight_hook(data):
        h = comfy.hooks.WeightHook()
        h.__dict__.update(data)
        return h
    registry.register("WeightHook", serialize_dict_state, deserialize_weight_hook)
    def serialize_set(obj):
        # Sets travel as lists under a marker key.
        return {"__set__": list(obj)}
    def deserialize_set(data):
        return set(data["__set__"])
    registry.register("set", serialize_set, deserialize_set)
    try:
        from comfy.weight_adapter.lora import LoRAAdapter
        def serialize_lora(obj):
            # NOTE(review): weights are intentionally sent as an empty dict —
            # confirm they are re-resolved on the receiving side.
            return {"weights": {}, "loaded_keys": list(obj.loaded_keys)}
        def deserialize_lora(data):
            return LoRAAdapter(set(data["loaded_keys"]), data["weights"])
        registry.register("LoRAAdapter", serialize_lora, deserialize_lora)
    except Exception:
        pass
    try:
        from comfy.hooks import _HookRef
        import uuid
        def serialize_hook_ref(obj):
            # Identity-only serialization: refs are matched up by UUID.
            return {
                "__hook_ref__": True,
                "id": getattr(obj, "_pyisolate_id", str(uuid.uuid4())),
            }
        def deserialize_hook_ref(data):
            h = _HookRef()
            h._pyisolate_id = data.get("id", str(uuid.uuid4()))
            return h
        registry.register("_HookRef", serialize_hook_ref, deserialize_hook_ref)
    except ImportError:
        pass
    except Exception as e:
        logger.warning(f"Failed to register _HookRef: {e}")
# Best-effort registration at import time; failures are logged, not raised.
try:
    register_hooks_serializers()
except Exception as e:
    logger.error(f"Failed to initialize hook serializers: {e}")

View File

@ -0,0 +1,253 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations
import asyncio
import logging
from typing import Any
from comfy.isolation.proxies.base import (
BaseProxy,
BaseRegistry,
detach_if_grad,
get_thread_loop,
run_coro_in_new_loop,
)
logger = logging.getLogger(__name__)
def _prefer_device(*tensors: Any) -> Any:
try:
import torch
except Exception:
return None
for t in tensors:
if isinstance(t, torch.Tensor) and t.is_cuda:
return t.device
for t in tensors:
if isinstance(t, torch.Tensor):
return t.device
return None
def _to_device(obj: Any, device: Any) -> Any:
try:
import torch
except Exception:
return obj
if device is None:
return obj
if isinstance(obj, torch.Tensor):
if obj.device != device:
return obj.to(device)
return obj
if isinstance(obj, (list, tuple)):
converted = [_to_device(x, device) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_device(v, device) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]):
    """Child-process RPC registry exposing model_sampling objects.

    Every method looks the instance up by id and forwards the call; tensor
    results pass through detach_if_grad so autograd state never crosses the
    RPC boundary.
    """
    _type_prefix = "modelsampling"
    async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.calculate_input(sigma, noise))
    async def calculate_denoised(
        self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(
            sampling.calculate_denoised(sigma, model_output, model_input)
        )
    async def noise_scaling(
        self,
        instance_id: str,
        sigma: Any,
        noise: Any,
        latent_image: Any,
        max_denoise: bool = False,
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(
            sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
        )
    async def inverse_noise_scaling(
        self, instance_id: str, sigma: Any, latent: Any
    ) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
    async def timestep(self, instance_id: str, sigma: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.timestep(sigma)
    async def sigma(self, instance_id: str, timestep: Any) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.sigma(timestep)
    async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
        sampling = self._get_instance(instance_id)
        return sampling.percent_to_sigma(percent)
    # Sigma properties are exposed as getters; detach_if_grad keeps any
    # returned tensors gradient-free.
    async def get_sigma_min(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_min)
    async def get_sigma_max(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_max)
    async def get_sigma_data(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigma_data)
    async def get_sigmas(self, instance_id: str) -> Any:
        sampling = self._get_instance(instance_id)
        return detach_if_grad(sampling.sigmas)
    async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
        sampling = self._get_instance(instance_id)
        sampling.set_sigmas(sigmas)
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
    """Child-side stand-in for a parent-process model-sampling object.

    Methods and sigma properties forward over RPC to the instance held in
    ``ModelSamplingRegistry``. When no child RPC channel exists (in-process
    use), ``_get_rpc`` builds a local shim that delegates directly to a
    registry; its coroutine results are resolved synchronously by ``_call``.
    """
    _registry_class = ModelSamplingRegistry
    # Present as the proxy module so repr/pickling identify the proxy type.
    __module__ = "comfy.isolation.model_sampling_proxy"
    def _get_rpc(self) -> Any:
        """Lazily create and cache the RPC caller (or local fallback)."""
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance
            rpc = get_child_rpc_instance()
            if rpc is not None:
                # Child process: route every call through the RPC channel.
                self._rpc_caller = rpc.create_caller(
                    ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
                )
            else:
                # No RPC channel available: fall back to an in-process registry.
                registry = ModelSamplingRegistry()
                class _LocalCaller:
                    # Registry methods are ``async def``; these wrappers return
                    # un-awaited coroutine objects which ``_call`` resolves.
                    def calculate_input(
                        self, instance_id: str, sigma: Any, noise: Any
                    ) -> Any:
                        return registry.calculate_input(instance_id, sigma, noise)
                    def calculate_denoised(
                        self,
                        instance_id: str,
                        sigma: Any,
                        model_output: Any,
                        model_input: Any,
                    ) -> Any:
                        return registry.calculate_denoised(
                            instance_id, sigma, model_output, model_input
                        )
                    def noise_scaling(
                        self,
                        instance_id: str,
                        sigma: Any,
                        noise: Any,
                        latent_image: Any,
                        max_denoise: bool = False,
                    ) -> Any:
                        return registry.noise_scaling(
                            instance_id, sigma, noise, latent_image, max_denoise
                        )
                    def inverse_noise_scaling(
                        self, instance_id: str, sigma: Any, latent: Any
                    ) -> Any:
                        return registry.inverse_noise_scaling(
                            instance_id, sigma, latent
                        )
                    def timestep(self, instance_id: str, sigma: Any) -> Any:
                        return registry.timestep(instance_id, sigma)
                    def sigma(self, instance_id: str, timestep: Any) -> Any:
                        return registry.sigma(instance_id, timestep)
                    def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
                        return registry.percent_to_sigma(instance_id, percent)
                    def get_sigma_min(self, instance_id: str) -> Any:
                        return registry.get_sigma_min(instance_id)
                    def get_sigma_max(self, instance_id: str) -> Any:
                        return registry.get_sigma_max(instance_id)
                    def get_sigma_data(self, instance_id: str) -> Any:
                        return registry.get_sigma_data(instance_id)
                    def get_sigmas(self, instance_id: str) -> Any:
                        return registry.get_sigmas(instance_id)
                    def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
                        return registry.set_sigmas(instance_id, sigmas)
                self._rpc_caller = _LocalCaller()
        return self._rpc_caller
    def _call(self, method_name: str, *args: Any) -> Any:
        """Invoke *method_name* on the caller; resolve coroutines synchronously."""
        rpc = self._get_rpc()
        method = getattr(rpc, method_name)
        result = method(self._instance_id, *args)
        if asyncio.iscoroutine(result):
            try:
                asyncio.get_running_loop()
                # Already inside an event loop: run the coroutine elsewhere.
                return run_coro_in_new_loop(result)
            except RuntimeError:
                # No running loop on this thread: use its dedicated loop.
                loop = get_thread_loop()
                return loop.run_until_complete(result)
        return result
    @property
    def sigma_min(self) -> Any:
        """Smallest sigma of the remote schedule."""
        return self._call("get_sigma_min")
    @property
    def sigma_max(self) -> Any:
        """Largest sigma of the remote schedule."""
        return self._call("get_sigma_max")
    @property
    def sigma_data(self) -> Any:
        """The remote model's ``sigma_data`` value."""
        return self._call("get_sigma_data")
    @property
    def sigmas(self) -> Any:
        """Full sigma schedule fetched from the remote instance."""
        return self._call("get_sigmas")
    def calculate_input(self, sigma: Any, noise: Any) -> Any:
        """Forward ``calculate_input`` to the remote sampling object."""
        return self._call("calculate_input", sigma, noise)
    def calculate_denoised(
        self, sigma: Any, model_output: Any, model_input: Any
    ) -> Any:
        """Forward ``calculate_denoised`` to the remote sampling object."""
        return self._call("calculate_denoised", sigma, model_output, model_input)
    def noise_scaling(
        self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
    ) -> Any:
        """Forward ``noise_scaling`` to the remote sampling object."""
        return self._call("noise_scaling", sigma, noise, latent_image, max_denoise)
    def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
        """Forward ``inverse_noise_scaling`` to the remote sampling object."""
        return self._call("inverse_noise_scaling", sigma, latent)
    def timestep(self, sigma: Any) -> Any:
        """Forward ``timestep`` to the remote sampling object."""
        return self._call("timestep", sigma)
    def sigma(self, timestep: Any) -> Any:
        """Forward ``sigma`` to the remote sampling object."""
        return self._call("sigma", timestep)
    def percent_to_sigma(self, percent: float) -> Any:
        """Forward ``percent_to_sigma`` to the remote sampling object."""
        return self._call("percent_to_sigma", percent)
    def set_sigmas(self, sigmas: Any) -> None:
        """Forward ``set_sigmas`` to the remote sampling object."""
        return self._call("set_sigmas", sigmas)

View File

@ -0,0 +1,214 @@
# pylint: disable=attribute-defined-outside-init
import logging
from typing import Any
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
BaseRegistry,
detach_if_grad,
)
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy, ModelPatcherRegistry
logger = logging.getLogger(__name__)
class FirstStageModelRegistry(BaseRegistry[Any]):
    """Parent-side registry exposing raw first-stage (autoencoder) models."""

    _type_prefix = "first_stage_model"

    async def get_property(self, instance_id: str, name: str) -> Any:
        """Fetch attribute *name* from the registered model."""
        return getattr(self._get_instance(instance_id), name)

    async def has_property(self, instance_id: str, name: str) -> bool:
        """Report whether the registered model has attribute *name*."""
        return hasattr(self._get_instance(instance_id), name)
class FirstStageModelProxy(BaseProxy[FirstStageModelRegistry]):
    """Child-process handle onto a parent-side first-stage model."""

    _registry_class = FirstStageModelRegistry
    __module__ = "comfy.ldm.models.autoencoder"

    def __getattr__(self, name: str) -> Any:
        """Resolve unknown attributes by asking the parent process."""
        try:
            value = self._call_rpc("get_property", name)
        except Exception as e:
            # Translate any remote failure into the standard attribute error.
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from e
        return value

    def __repr__(self) -> str:
        return f"<FirstStageModelProxy {self._instance_id}>"
class VAERegistry(BaseRegistry[Any]):
    """Parent-side registry servicing ``VAEProxy`` RPC calls.

    Tensor results pass through ``detach_if_grad`` so gradient-tracking
    tensors never cross the process boundary.
    """

    _type_prefix = "vae"

    async def get_patcher_id(self, instance_id: str) -> str:
        """Register the VAE's patcher and return an id a proxy can attach to.

        NOTE(review): presumably ``ModelPatcherRegistry()`` shares storage
        across instances (class-level state in BaseRegistry) — confirm.
        """
        vae = self._get_instance(instance_id)
        return ModelPatcherRegistry().register(vae.patcher)

    async def get_first_stage_model_id(self, instance_id: str) -> str:
        """Register the raw autoencoder module and return its instance id."""
        vae = self._get_instance(instance_id)
        return FirstStageModelRegistry().register(vae.first_stage_model)

    async def encode(self, instance_id: str, pixels: Any) -> Any:
        """Encode pixels into latents."""
        return detach_if_grad(self._get_instance(instance_id).encode(pixels))

    async def encode_tiled(
        self,
        instance_id: str,
        pixels: Any,
        tile_x: int = 512,
        tile_y: int = 512,
        overlap: int = 64,
    ) -> Any:
        """Tiled encode for inputs too large for a single pass."""
        return detach_if_grad(
            self._get_instance(instance_id).encode_tiled(
                pixels, tile_x=tile_x, tile_y=tile_y, overlap=overlap
            )
        )

    async def decode(self, instance_id: str, samples: Any, **kwargs: Any) -> Any:
        """Decode latents back into pixels."""
        return detach_if_grad(self._get_instance(instance_id).decode(samples, **kwargs))

    async def decode_tiled(
        self,
        instance_id: str,
        samples: Any,
        tile_x: int = 64,
        tile_y: int = 64,
        overlap: int = 16,
        **kwargs: Any,
    ) -> Any:
        """Tiled decode for latents too large for a single pass."""
        return detach_if_grad(
            self._get_instance(instance_id).decode_tiled(
                samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap, **kwargs
            )
        )

    async def get_sd(self, instance_id: str) -> Any:
        """Return the VAE's state dict.

        Endpoint for ``VAEProxy.get_sd``, which previously had no matching
        registry method and would fail at dispatch time.
        """
        return detach_if_grad(self._get_instance(instance_id).get_sd())

    async def get_property(self, instance_id: str, name: str) -> Any:
        """Fetch attribute *name* from the registered VAE."""
        return getattr(self._get_instance(instance_id), name)

    async def memory_used_encode(self, instance_id: str, shape: Any, dtype: Any) -> int:
        """Estimate memory needed to encode an input of *shape*/*dtype*."""
        return self._get_instance(instance_id).memory_used_encode(shape, dtype)

    async def memory_used_decode(self, instance_id: str, shape: Any, dtype: Any) -> int:
        """Estimate memory needed to decode a latent of *shape*/*dtype*."""
        return self._get_instance(instance_id).memory_used_decode(shape, dtype)

    async def process_input(self, instance_id: str, image: Any) -> Any:
        """Run the VAE's input preprocessing on *image*."""
        return detach_if_grad(self._get_instance(instance_id).process_input(image))

    async def process_output(self, instance_id: str, image: Any) -> Any:
        """Run the VAE's output postprocessing on *image*."""
        return detach_if_grad(self._get_instance(instance_id).process_output(image))
class VAEProxy(BaseProxy[VAERegistry]):
    """Child-process proxy for a parent-side VAE.

    Encode/decode calls and attribute reads forward over RPC to
    ``VAERegistry``; heavy sub-objects (``patcher``, ``first_stage_model``)
    are exposed as their own lazily-created, cached proxies.
    """

    _registry_class = VAERegistry
    # Masquerade as the real module so repr/pickling line up with comfy.sd.VAE.
    __module__ = "comfy.sd"

    @property
    def patcher(self) -> ModelPatcherProxy:
        """Cached proxy for the VAE's ModelPatcher (created on first access)."""
        if not hasattr(self, "_patcher_proxy"):
            patcher_id = self._call_rpc("get_patcher_id")
            self._patcher_proxy = ModelPatcherProxy(patcher_id, manage_lifecycle=False)
        return self._patcher_proxy

    @property
    def first_stage_model(self) -> FirstStageModelProxy:
        """Cached proxy for the raw autoencoder module."""
        if not hasattr(self, "_first_stage_model_proxy"):
            fsm_id = self._call_rpc("get_first_stage_model_id")
            self._first_stage_model_proxy = FirstStageModelProxy(
                fsm_id, manage_lifecycle=False
            )
        return self._first_stage_model_proxy

    def _get_property(self, name: str) -> Any:
        """Read attribute *name* from the remote VAE."""
        return self._call_rpc("get_property", name)

    @property
    def vae_dtype(self) -> Any:
        return self._get_property("vae_dtype")

    def encode(self, pixels: Any) -> Any:
        """Encode pixels into latents on the parent-side VAE."""
        return self._call_rpc("encode", pixels)

    def encode_tiled(
        self, pixels: Any, tile_x: int = 512, tile_y: int = 512, overlap: int = 64
    ) -> Any:
        """Tiled encode for inputs too large for a single pass."""
        return self._call_rpc("encode_tiled", pixels, tile_x, tile_y, overlap)

    def decode(self, samples: Any, **kwargs: Any) -> Any:
        """Decode latents back into pixels on the parent-side VAE."""
        return self._call_rpc("decode", samples, **kwargs)

    def decode_tiled(
        self,
        samples: Any,
        tile_x: int = 64,
        tile_y: int = 64,
        overlap: int = 16,
        **kwargs: Any,
    ) -> Any:
        """Tiled decode for latents too large for a single pass."""
        return self._call_rpc(
            "decode_tiled", samples, tile_x, tile_y, overlap, **kwargs
        )

    def get_sd(self) -> Any:
        """Return the VAE's state dict from the parent process."""
        return self._call_rpc("get_sd")

    @property
    def latent_dim(self) -> int:
        return self._get_property("latent_dim")

    @property
    def latent_channels(self) -> int:
        return self._get_property("latent_channels")

    @property
    def downscale_ratio(self) -> Any:
        return self._get_property("downscale_ratio")

    @property
    def upscale_ratio(self) -> Any:
        return self._get_property("upscale_ratio")

    @property
    def output_channels(self) -> int:
        return self._get_property("output_channels")

    @property
    def not_video(self) -> bool:
        """Mirror of the remote VAE's ``not_video`` attribute."""
        return self._get_property("not_video")

    @property
    def check_not_vide(self) -> bool:
        # Misspelled legacy name kept for backward compatibility;
        # prefer ``not_video`` above (both read the same remote attribute).
        return self._get_property("not_video")

    @property
    def device(self) -> Any:
        return self._get_property("device")

    @property
    def working_dtypes(self) -> Any:
        return self._get_property("working_dtypes")

    @property
    def disable_offload(self) -> bool:
        return self._get_property("disable_offload")

    @property
    def size(self) -> Any:
        return self._get_property("size")

    def memory_used_encode(self, shape: Any, dtype: Any) -> int:
        """Estimate memory needed to encode an input of *shape*/*dtype*."""
        return self._call_rpc("memory_used_encode", shape, dtype)

    def memory_used_decode(self, shape: Any, dtype: Any) -> int:
        """Estimate memory needed to decode a latent of *shape*/*dtype*."""
        return self._call_rpc("memory_used_decode", shape, dtype)

    def process_input(self, image: Any) -> Any:
        """Run the VAE's input preprocessing remotely."""
        return self._call_rpc("process_input", image)

    def process_output(self, image: Any) -> Any:
        """Run the VAE's output postprocessing remotely."""
        return self._call_rpc("process_output", image)
# Parent-process side only: instantiate the registries at import time so
# their RPC endpoints exist before any child proxy tries to reach them.
# NOTE(review): presumably BaseRegistry construction registers the instance
# globally — confirm these module-level names are the intended access path.
if not IS_CHILD_PROCESS:
    _VAE_REGISTRY_SINGLETON = VAERegistry()
    _FIRST_STAGE_MODEL_REGISTRY_SINGLETON = FirstStageModelRegistry()