ComfyUI/comfy/isolation/model_patcher_proxy.py
John Pollock 26edd5663d
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
isolation+dynamicvram: stabilize ModelPatcher RPC path, add diagnostics; known process_latent_in timeout remains
- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers
- improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups
- add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing
- preserve isolation-first lifecycle handling and boundary cleanup sequencing
- validate isolated workflows: most targeted runs now complete under
  --use-sage-attention --use-process-isolation --disable-cuda-malloc

Known issue (reproducible):
- isolation_99_full_iso_stack still times out at SamplerCustom_ISO path
- the failure is an explicit RPC timeout:
  ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000)
- this indicates that the remaining stall is on the process_latent_in RPC path, not in generic startup or the manager fetch
2026-03-04 10:41:33 -06:00

862 lines
29 KiB
Python

# pylint: disable=bare-except,consider-using-from-import,import-outside-toplevel,protected-access
# RPC proxy for ModelPatcher (parent process)
from __future__ import annotations
import logging
from typing import Any, Optional, List, Set, Dict, Callable
from comfy.isolation.proxies.base import (
IS_CHILD_PROCESS,
BaseProxy,
)
from comfy.isolation.model_patcher_proxy_registry import (
ModelPatcherRegistry,
AutoPatcherEjector,
)
# Module-level logger; ModelPatcherProxy.cleanup() reports failed cleanup RPCs
# through it at DEBUG level.
logger = logging.getLogger(__name__)
class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
    """RPC proxy exposing the ``ModelPatcher`` API for an isolated model.

    Each method forwards, via ``_call_rpc``, to a method of a
    ``ModelPatcherRegistry`` for the patcher identified by ``_instance_id``.
    Where the RPC transport changes a value's shape -- tuples arriving as
    lists, non-string dict keys arriving as JSON-encoded strings, dtypes
    arriving as ``"torch.<name>"`` strings -- the proxy restores it locally.
    """

    _registry_class = ModelPatcherRegistry
    # Report the real ModelPatcher's module; presumably so module-based checks
    # elsewhere treat the proxy like a ModelPatcher (TODO confirm).
    __module__ = "comfy.model_patcher"
    # Extra free VRAM (32 MiB) required beyond the bytes staged by apply_model().
    _APPLY_MODEL_GUARD_PADDING_BYTES = 32 * 1024 * 1024
    # Marker prefix of dict keys that were JSON-encoded to cross the RPC boundary.
    _PYISOLATE_KEY_PREFIX = "__pyisolate_key__"
    # Marker prefix of hook_patches keys that reference a remote _HookRef by id.
    _HOOKREF_KEY_PREFIX = "PYISOLATE_HOOKREF:"
    # Patcher attributes that __getattr__ fetches remotely via "get_patcher_attr".
    # Built once here instead of on every attribute miss.
    _WHITELISTED_PATCHER_ATTRS = frozenset(
        {
            "hook_patches_backup",
            "hook_backup",
            "cached_hook_patches",
            "current_hooks",
            "forced_hooks",
            "is_clip",
            "patches_uuid",
            "pinned",
            "attachments",
            "additional_models",
            "injections",
            "hook_patches",
            "model_lowvram",
            "model_loaded_weight_memory",
            "backup",
            "object_patches_backup",
            "weight_wrapper_patches",
            "weight_inplace_update",
            "force_cast_weights",
        }
    )

    # ------------------------------------------------------------------
    # RPC plumbing
    # ------------------------------------------------------------------
    def _get_rpc(self) -> Any:
        """Return (and cache in ``_rpc_caller``) the object RPC calls go through.

        When ``get_child_rpc_instance()`` returns a channel, a caller bound to
        the registry's remote id is created; otherwise the local registry
        (``self._registry``) is used directly.
        """
        if self._rpc_caller is None:
            from pyisolate._internal.rpc_protocol import get_child_rpc_instance

            rpc = get_child_rpc_instance()
            if rpc is not None:
                self._rpc_caller = rpc.create_caller(
                    self._registry_class, self._registry_class.get_remote_id()
                )
            else:
                self._rpc_caller = self._registry
        return self._rpc_caller

    # ------------------------------------------------------------------
    # Callbacks, wrappers and hooks
    # ------------------------------------------------------------------
    def get_all_callbacks(self, call_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_callbacks", call_type)

    def get_all_wrappers(self, wrapper_type: Optional[str] = None) -> Any:
        return self._call_rpc("get_all_wrappers", wrapper_type)

    def _load_list(self, *args, **kwargs) -> Any:
        # Exposed remotely under a different name ("load_list_internal").
        return self._call_rpc("load_list_internal", *args, **kwargs)

    def prepare_hook_patches_current_keyframe(
        self, t: Any, hook_group: Any, model_options: Any
    ) -> None:
        self._call_rpc(
            "prepare_hook_patches_current_keyframe", t, hook_group, model_options
        )

    def add_hook_patches(
        self,
        hook: Any,
        patches: Any,
        strength_patch: float = 1.0,
        strength_model: float = 1.0,
    ) -> None:
        self._call_rpc(
            "add_hook_patches", hook, patches, strength_patch, strength_model
        )

    def clear_cached_hook_weights(self) -> None:
        self._call_rpc("clear_cached_hook_weights")

    def get_combined_hook_patches(self, hooks: Any) -> Any:
        return self._call_rpc("get_combined_hook_patches", hooks)

    def get_additional_models_with_key(self, key: str) -> Any:
        return self._call_rpc("get_additional_models_with_key", key)

    @property
    def object_patches(self) -> Any:
        return self._call_rpc("get_object_patches")

    @property
    def patches(self) -> Any:
        """Remote patch table; list entries are converted back to tuples.

        Non-dict results are returned unchanged.
        """
        res = self._call_rpc("get_patches")
        if not isinstance(res, dict):
            return res
        return {
            key: [tuple(item) if isinstance(item, list) else item for item in entries]
            for key, entries in res.items()
        }

    @property
    def pinned(self) -> Set:
        """Remote ``pinned`` attribute as a set (empty set when the RPC returns None)."""
        val = self._call_rpc("get_patcher_attr", "pinned")
        return set(val) if val is not None else set()

    @property
    def hook_patches(self) -> Dict:
        """Remote ``hook_patches`` with encoded hook-reference keys rebuilt.

        Two key encodings are recognised:

        * ``"PYISOLATE_HOOKREF:<id>"`` -> a ``_HookRef`` whose ``_pyisolate_id``
          is ``<id>``;
        * ``"__pyisolate_key__<json>"`` where ``<json>`` is a list containing an
          ``["id", <id>]`` pair -> a ``_HookRef`` with that id.

        Keys that match neither encoding, or fail to decode, are kept as-is.
        Returns ``{}`` when the RPC returns None, and the raw value when
        ``comfy.hooks`` cannot be imported.
        """
        val = self._call_rpc("get_patcher_attr", "hook_patches")
        if val is None:
            return {}
        try:
            from comfy.hooks import _HookRef
        except ImportError:
            return val
        import json

        def _make_ref(ref_id: Any) -> Any:
            ref = _HookRef()
            ref._pyisolate_id = ref_id
            return ref

        hookref_prefix = self._HOOKREF_KEY_PREFIX
        key_prefix = self._PYISOLATE_KEY_PREFIX
        decoded = {}
        for k, v in val.items():
            if isinstance(k, str) and k.startswith(hookref_prefix):
                decoded[_make_ref(k[len(hookref_prefix):])] = v
            elif isinstance(k, str) and k.startswith(key_prefix):
                try:
                    data = json.loads(k[len(key_prefix):])
                    ref_id = None
                    if isinstance(data, list):
                        ref_id = next(
                            (
                                item[1]
                                for item in data
                                if isinstance(item, list)
                                and len(item) == 2
                                and item[0] == "id"
                            ),
                            None,
                        )
                    if ref_id:
                        decoded[_make_ref(ref_id)] = v
                    else:
                        decoded[k] = v
                except Exception:
                    # Best-effort decode: keep the encoded key rather than fail.
                    decoded[k] = v
            else:
                decoded[k] = v
        return decoded

    def set_hook_mode(self, hook_mode: Any) -> None:
        self._call_rpc("set_hook_mode", hook_mode)

    def register_all_hook_patches(
        self,
        hooks: Any,
        target_dict: Any,
        model_options: Any = None,
        registered: Any = None,
    ) -> None:
        self._call_rpc(
            "register_all_hook_patches", hooks, target_dict, model_options, registered
        )

    # ------------------------------------------------------------------
    # Cloning and identity
    # ------------------------------------------------------------------
    def is_clone(self, other: Any) -> bool:
        """True if ``other`` is a proxy whose remote patcher is a clone of ours."""
        if isinstance(other, ModelPatcherProxy):
            return self._call_rpc("is_clone_by_id", other._instance_id)
        return False

    def clone(self) -> ModelPatcherProxy:
        """Clone remotely and return a proxy for the new patcher id.

        The new proxy manages the remote lifecycle only outside a child process.
        """
        new_id = self._call_rpc("clone")
        return ModelPatcherProxy(
            new_id, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
        )

    def clone_has_same_weights(self, clone: Any) -> bool:
        if isinstance(clone, ModelPatcherProxy):
            return self._call_rpc("clone_has_same_weights_by_id", clone._instance_id)
        if not IS_CHILD_PROCESS:
            # NOTE(review): this fallback asks the registry for "is_clone", which
            # may be a weaker check than "same weights" -- confirm intended.
            return self._call_rpc("is_clone", clone)
        return False

    def get_model_object(self, name: str) -> Any:
        return self._call_rpc("get_model_object", name)

    # ------------------------------------------------------------------
    # model_options
    # ------------------------------------------------------------------
    @property
    def model_options(self) -> dict:
        """Remote ``model_options`` with JSON-encoded dict keys decoded.

        Keys of the form ``"__pyisolate_key__<json>"`` are decoded with every
        JSON list (at any depth) turned into a tuple, so keys such as
        ``("attn1", 0)`` come back hashable. A key that cannot be parsed, or
        whose decoded value is still unhashable, is left as the encoded string.
        Nested dicts and lists are decoded recursively.
        """
        data = self._call_rpc("get_model_options")
        import json

        prefix = self._PYISOLATE_KEY_PREFIX

        def _to_hashable(value: Any) -> Any:
            if isinstance(value, list):
                return tuple(_to_hashable(x) for x in value)
            return value

        def _decode_key(key: Any) -> Any:
            if not (isinstance(key, str) and key.startswith(prefix)):
                return key
            try:
                decoded = _to_hashable(json.loads(key[len(prefix):]))
                hash(decoded)  # e.g. a decoded JSON object is unhashable
            except (ValueError, TypeError):
                return key
            return decoded

        def _decode_keys(obj: Any) -> Any:
            if isinstance(obj, dict):
                return {_decode_key(k): _decode_keys(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_decode_keys(x) for x in obj]
            return obj

        return _decode_keys(data)

    @model_options.setter
    def model_options(self, value: dict) -> None:
        self._call_rpc("set_model_options", value)

    def apply_hooks(self, hooks: Any) -> Any:
        return self._call_rpc("apply_hooks", hooks)

    def prepare_state(self, timestep: Any) -> Any:
        return self._call_rpc("prepare_state", timestep)

    def restore_hook_patches(self) -> None:
        self._call_rpc("restore_hook_patches")

    def unpatch_hooks(self, whitelist_keys_set: Optional[Set[str]] = None) -> None:
        self._call_rpc("unpatch_hooks", whitelist_keys_set)

    # ------------------------------------------------------------------
    # Loading / unloading
    # ------------------------------------------------------------------
    def model_patches_to(self, device: Any) -> Any:
        return self._call_rpc("model_patches_to", device)

    def partially_load(
        self, device: Any, extra_memory: Any, force_patch_weights: bool = False
    ) -> Any:
        return self._call_rpc(
            "partially_load", device, extra_memory, force_patch_weights
        )

    def partially_unload(
        self, device_to: Any, memory_to_free: int = 0, force_patch_weights: bool = False
    ) -> int:
        return self._call_rpc(
            "partially_unload", device_to, memory_to_free, force_patch_weights
        )

    def load(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        force_patch_weights: bool = False,
        full_load: bool = False,
    ) -> None:
        self._call_rpc(
            "load", device_to, lowvram_model_memory, force_patch_weights, full_load
        )

    def patch_model(
        self,
        device_to: Any = None,
        lowvram_model_memory: int = 0,
        load_weights: bool = True,
        force_patch_weights: bool = False,
    ) -> Any:
        """Patch the remote model and return this proxy (not the remote result)."""
        self._call_rpc(
            "patch_model",
            device_to,
            lowvram_model_memory,
            load_weights,
            force_patch_weights,
        )
        return self

    def unpatch_model(
        self, device_to: Any = None, unpatch_weights: bool = True
    ) -> None:
        self._call_rpc("unpatch_model", device_to, unpatch_weights)

    def detach(self, unpatch_all: bool = True) -> Any:
        """Detach remotely and return the inner-model proxy."""
        self._call_rpc("detach", unpatch_all)
        return self.model

    # ------------------------------------------------------------------
    # apply_model and its VRAM staging helpers
    # ------------------------------------------------------------------
    def _cpu_tensor_bytes(self, obj: Any) -> int:
        """Total ``nbytes`` of CPU tensors in ``obj`` (recursing dict/list/tuple)."""
        import torch

        if isinstance(obj, torch.Tensor):
            return obj.nbytes if obj.device.type == "cpu" else 0
        if isinstance(obj, dict):
            return sum(self._cpu_tensor_bytes(v) for v in obj.values())
        if isinstance(obj, (list, tuple)):
            return sum(self._cpu_tensor_bytes(v) for v in obj)
        return 0

    def _resolve_load_device(self) -> Any:
        """Return ``self.load_device`` as a ``torch.device``, or None if unparseable.

        Integer values are interpreted as CUDA ordinals (``cuda:<n>``).
        """
        import torch

        raw = self.load_device
        try:
            if isinstance(raw, torch.device):
                return raw
            if isinstance(raw, int):
                return torch.device(f"cuda:{raw}")
            return torch.device(raw)
        except Exception:
            return None

    def _ensure_apply_model_headroom(self, required_bytes: int, target: Any = None) -> bool:
        """Best-effort: make ``required_bytes`` plus padding free on the CUDA load device.

        Args:
            required_bytes: bytes about to be copied onto the device.
            target: already-resolved ``torch.device``; when None, the load
                device is fetched and resolved here.

        Returns:
            True if enough memory is free afterwards, or if there is nothing to
            check (no bytes required, device unresolvable or not CUDA).
        """
        if required_bytes <= 0:
            return True
        import comfy.model_management as model_management

        if target is None:
            target = self._resolve_load_device()
        if target is None or target.type != "cuda":
            return True
        required = required_bytes + self._APPLY_MODEL_GUARD_PADDING_BYTES

        def _has_headroom() -> bool:
            return model_management.get_free_memory(target) >= required

        if _has_headroom():
            return True
        # Escalating cleanup; stop checking once enough memory is free.
        model_management.cleanup_models_gc()
        model_management.cleanup_models()
        model_management.soft_empty_cache()
        if not _has_headroom():
            model_management.free_memory(required, target, for_dynamic=True)
            model_management.soft_empty_cache()
        if not _has_headroom():
            # Escalate to non-dynamic unloading before dispatching CUDA transfer.
            model_management.free_memory(required, target, for_dynamic=False)
            model_management.soft_empty_cache()
        if not _has_headroom():
            # Last resort; assumes load_models_gpu evicts other models to reach
            # minimum_memory_required -- TODO confirm.
            model_management.load_models_gpu(
                [self],
                minimum_memory_required=required,
            )
        return _has_headroom()

    def apply_model(self, *args, **kwargs) -> Any:
        """Run the inner model's ``apply_model`` remotely.

        For non-DynamicVRAM patchers whose load device is CUDA, CPU tensors in
        ``args``/``kwargs`` are first copied onto that exact device, after
        trying to free enough memory (retried once on OOM). Other load devices
        receive the arguments unchanged, which avoids the crash of moving to
        CUDA on CPU/MPS setups and mis-placement on multi-GPU ones. The result
        is moved back to the device of the first top-level tensor argument, if
        any.
        """
        import torch

        def _preferred_device() -> Any:
            for value in (*args, *kwargs.values()):
                if isinstance(value, torch.Tensor):
                    return value.device
            return None

        def _move_result_to_device(obj: Any, device: Any) -> Any:
            if device is None:
                return obj
            if isinstance(obj, torch.Tensor):
                return obj.to(device) if obj.device != device else obj
            if isinstance(obj, dict):
                return {k: _move_result_to_device(v, device) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_move_result_to_device(v, device) for v in obj]
            if isinstance(obj, tuple):
                return tuple(_move_result_to_device(v, device) for v in obj)
            return obj

        def _stage(obj: Any, device: Any) -> Any:
            # Copy CPU tensors (recursing dict/list/tuple) onto ``device``.
            if isinstance(obj, torch.Tensor) and obj.device.type == "cpu":
                return obj.to(device)
            if isinstance(obj, dict):
                return {k: _stage(v, device) for k, v in obj.items()}
            if isinstance(obj, list):
                return [_stage(v, device) for v in obj]
            if isinstance(obj, tuple):
                return tuple(_stage(v, device) for v in obj)
            return obj

        def _dispatch(call_args: Any, call_kwargs: Any) -> Any:
            out = self._call_rpc("inner_model_apply_model", call_args, call_kwargs)
            return _move_result_to_device(out, _preferred_device())

        # DynamicVRAM models must keep load/offload decisions in host process.
        # Child-side CUDA staging here can deadlock before first inference RPC.
        if self.is_dynamic():
            return _dispatch(args, kwargs)

        target = self._resolve_load_device()
        if target is None or target.type != "cuda":
            # Nothing to stage for non-CUDA (or unknown) load devices.
            return _dispatch(args, kwargs)

        required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
        self._ensure_apply_model_headroom(required_bytes, target=target)
        try:
            staged_args = _stage(args, target)
            staged_kwargs = _stage(kwargs, target)
        except torch.OutOfMemoryError:
            self._ensure_apply_model_headroom(required_bytes, target=target)
            staged_args = _stage(args, target)
            staged_kwargs = _stage(kwargs, target)
        return _dispatch(staged_args, staged_kwargs)

    # ------------------------------------------------------------------
    # Weights and patches
    # ------------------------------------------------------------------
    def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
        """Dict of remote state-dict keys mapped to None (tensors are not transferred)."""
        keys = self._call_rpc("model_state_dict", filter_prefix)
        return dict.fromkeys(keys, None)

    def add_patches(self, *args: Any, **kwargs: Any) -> Any:
        """Forward ``add_patches``; list entries of a list result become tuples."""
        res = self._call_rpc("add_patches", *args, **kwargs)
        if isinstance(res, list):
            return [tuple(x) if isinstance(x, list) else x for x in res]
        return res

    def get_key_patches(self, filter_prefix: Optional[str] = None) -> Any:
        return self._call_rpc("get_key_patches", filter_prefix)

    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        self._call_rpc("patch_weight_to_device", key, device_to, inplace_update)

    def pin_weight_to_device(self, key):
        self._call_rpc("pin_weight_to_device", key)

    def unpin_weight(self, key):
        self._call_rpc("unpin_weight", key)

    def unpin_all_weights(self):
        self._call_rpc("unpin_all_weights")

    def calculate_weight(self, patches, weight, key, intermediate_dtype=None):
        return self._call_rpc(
            "calculate_weight", patches, weight, key, intermediate_dtype
        )

    # ------------------------------------------------------------------
    # Injection
    # ------------------------------------------------------------------
    def inject_model(self) -> None:
        self._call_rpc("inject_model")

    def eject_model(self) -> None:
        self._call_rpc("eject_model")

    def use_ejected(self, skip_and_inject_on_exit_only: bool = False) -> Any:
        return AutoPatcherEjector(
            self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only
        )

    @property
    def is_injected(self) -> bool:
        return self._call_rpc("get_is_injected")

    @property
    def skip_injection(self) -> bool:
        return self._call_rpc("get_skip_injection")

    @skip_injection.setter
    def skip_injection(self, value: bool) -> None:
        self._call_rpc("set_skip_injection", value)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def clean_hooks(self) -> None:
        self._call_rpc("clean_hooks")

    def pre_run(self) -> None:
        self._call_rpc("pre_run")

    def cleanup(self) -> None:
        """Run remote cleanup (failures logged at DEBUG), then the base cleanup."""
        try:
            self._call_rpc("cleanup")
        except Exception:
            logger.debug(
                "ModelPatcherProxy cleanup RPC failed for %s",
                self._instance_id,
                exc_info=True,
            )
        finally:
            super().cleanup()

    @property
    def model(self) -> _InnerModelProxy:
        """Proxy for the patcher's inner model (a new wrapper on every access)."""
        return _InnerModelProxy(self)

    def __getattr__(self, name: str) -> Any:
        """Fetch whitelisted patcher attributes remotely; others raise AttributeError."""
        if name in self._WHITELISTED_PATCHER_ATTRS:
            return self._call_rpc("get_patcher_attr", name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def load_lora(
        self,
        lora_path: str,
        strength_model: float,
        clip: Optional[Any] = None,
        strength_clip: float = 1.0,
    ) -> tuple:
        """Load a LoRA remotely; return ``(model_proxy_or_None, clip_proxy_or_None)``.

        ``clip`` is sent by id (its ``_instance_id``, else ``_clip_id``).
        """
        clip_id = None
        if clip is not None:
            clip_id = getattr(clip, "_instance_id", getattr(clip, "_clip_id", None))
        result = self._call_rpc(
            "load_lora", lora_path, strength_model, clip_id, strength_clip
        )
        new_model = None
        if result.get("model_id"):
            new_model = ModelPatcherProxy(
                result["model_id"],
                self._registry,
                manage_lifecycle=not IS_CHILD_PROCESS,
            )
        new_clip = None
        if result.get("clip_id"):
            from comfy.isolation.clip_proxy import CLIPProxy

            new_clip = CLIPProxy(result["clip_id"])
        return (new_model, new_clip)

    # ------------------------------------------------------------------
    # Devices, sizes and state
    # ------------------------------------------------------------------
    @property
    def load_device(self) -> Any:
        return self._call_rpc("get_load_device")

    @property
    def offload_device(self) -> Any:
        return self._call_rpc("get_offload_device")

    @property
    def device(self) -> Any:
        return self.load_device

    def current_loaded_device(self) -> Any:
        return self._call_rpc("current_loaded_device")

    @property
    def size(self) -> int:
        return self._call_rpc("get_size")

    def model_size(self) -> Any:
        return self._call_rpc("model_size")

    def loaded_size(self) -> Any:
        return self._call_rpc("loaded_size")

    def get_ram_usage(self) -> int:
        return self._call_rpc("get_ram_usage")

    def lowvram_patch_counter(self) -> int:
        return self._call_rpc("lowvram_patch_counter")

    def memory_required(self, input_shape: Any) -> Any:
        return self._call_rpc("memory_required", input_shape)

    def get_operation_state(self) -> Dict[str, Any]:
        """Remote operation-state dict, or ``{}`` if the RPC returns a non-dict."""
        state = self._call_rpc("get_operation_state")
        return state if isinstance(state, dict) else {}

    def wait_for_idle(self, timeout_ms: int = 0) -> bool:
        return bool(self._call_rpc("wait_for_idle", timeout_ms))

    def is_dynamic(self) -> bool:
        return bool(self._call_rpc("is_dynamic"))

    def get_free_memory(self, device: Any) -> Any:
        return self._call_rpc("get_free_memory", device)

    def partially_unload_ram(self, ram_to_unload: int) -> Any:
        return self._call_rpc("partially_unload_ram", ram_to_unload)

    def model_dtype(self) -> Any:
        """Remote model dtype; ``"torch.<name>"`` strings map to ``torch.<name>``.

        Strings naming no torch attribute, or used when torch is unavailable,
        are returned unchanged, as is any non-string result.
        """
        res = self._call_rpc("model_dtype")
        if isinstance(res, str) and res.startswith("torch."):
            try:
                import torch
            except ImportError:
                return res
            return getattr(torch, res.split(".")[-1], res)
        return res

    @property
    def hook_mode(self) -> Any:
        return self._call_rpc("get_hook_mode")

    @hook_mode.setter
    def hook_mode(self, value: Any) -> None:
        self._call_rpc("set_hook_mode", value)

    # ------------------------------------------------------------------
    # Sampler functions and model patches
    # ------------------------------------------------------------------
    def set_model_sampler_cfg_function(
        self, sampler_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_cfg_function",
            sampler_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_post_cfg_function(
        self, post_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_post_cfg_function",
            post_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_pre_cfg_function(
        self, pre_cfg_function: Any, disable_cfg1_optimization: bool = False
    ) -> None:
        self._call_rpc(
            "set_model_sampler_pre_cfg_function",
            pre_cfg_function,
            disable_cfg1_optimization,
        )

    def set_model_sampler_calc_cond_batch_function(self, fn: Any) -> None:
        self._call_rpc("set_model_sampler_calc_cond_batch_function", fn)

    def set_model_unet_function_wrapper(self, unet_wrapper_function: Any) -> None:
        self._call_rpc("set_model_unet_function_wrapper", unet_wrapper_function)

    def set_model_denoise_mask_function(self, denoise_mask_function: Any) -> None:
        self._call_rpc("set_model_denoise_mask_function", denoise_mask_function)

    def set_model_patch(self, patch: Any, name: str) -> None:
        self._call_rpc("set_model_patch", patch, name)

    def set_model_patch_replace(
        self,
        patch: Any,
        name: str,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self._call_rpc(
            "set_model_patch_replace",
            patch,
            name,
            block_name,
            number,
            transformer_index,
        )

    # Convenience wrappers around set_model_patch / set_model_patch_replace.
    def set_model_attn1_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn1", block_name, number, transformer_index
        )

    def set_model_attn2_replace(
        self,
        patch: Any,
        block_name: str,
        number: int,
        transformer_index: Optional[int] = None,
    ) -> None:
        self.set_model_patch_replace(
            patch, "attn2", block_name, number, transformer_index
        )

    def set_model_attn1_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch: Any) -> None:
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "output_block_patch")

    def set_model_emb_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "emb_patch")

    def set_model_forward_timestep_embed_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "forward_timestep_embed_patch")

    def set_model_double_block_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "double_block")

    def set_model_post_input_patch(self, patch: Any) -> None:
        self.set_model_patch(patch, "post_input")

    def set_model_rope_options(
        self,
        scale_x=1.0,
        shift_x=0.0,
        scale_y=1.0,
        shift_y=0.0,
        scale_t=1.0,
        shift_t=0.0,
        **kwargs: Any,
    ) -> None:
        """Send RoPE options as one dict; extra kwargs override the named ones."""
        options = {
            "scale_x": scale_x,
            "shift_x": shift_x,
            "scale_y": scale_y,
            "shift_y": shift_y,
            "scale_t": scale_t,
            "shift_t": shift_t,
        }
        options.update(kwargs)
        self._call_rpc("set_model_rope_options", options)

    def set_model_compute_dtype(self, dtype: Any) -> None:
        self._call_rpc("set_model_compute_dtype", dtype)

    def add_object_patch(self, name: str, obj: Any) -> None:
        self._call_rpc("add_object_patch", name, obj)

    def add_weight_wrapper(self, name: str, function: Any) -> None:
        self._call_rpc("add_weight_wrapper", name, function)

    # ------------------------------------------------------------------
    # Wrappers and callbacks (keyed)
    # ------------------------------------------------------------------
    def add_wrapper_with_key(self, wrapper_type: Any, key: str, fn: Any) -> None:
        self._call_rpc("add_wrapper_with_key", wrapper_type, key, fn)

    def add_wrapper(self, wrapper_type: str, wrapper: Callable) -> None:
        self.add_wrapper_with_key(wrapper_type, None, wrapper)

    def remove_wrappers_with_key(self, wrapper_type: str, key: str) -> None:
        self._call_rpc("remove_wrappers_with_key", wrapper_type, key)

    @property
    def wrappers(self) -> Any:
        return self._call_rpc("get_wrappers")

    def add_callback_with_key(self, call_type: str, key: str, callback: Any) -> None:
        self._call_rpc("add_callback_with_key", call_type, key, callback)

    def add_callback(self, call_type: str, callback: Any) -> None:
        self.add_callback_with_key(call_type, None, callback)

    def remove_callbacks_with_key(self, call_type: str, key: str) -> None:
        self._call_rpc("remove_callbacks_with_key", call_type, key)

    @property
    def callbacks(self) -> Any:
        return self._call_rpc("get_callbacks")

    # ------------------------------------------------------------------
    # Attachments, injections and additional models
    # ------------------------------------------------------------------
    def set_attachments(self, key: str, attachment: Any) -> None:
        self._call_rpc("set_attachments", key, attachment)

    def get_attachment(self, key: str) -> Any:
        return self._call_rpc("get_attachment", key)

    def remove_attachments(self, key: str) -> None:
        self._call_rpc("remove_attachments", key)

    def set_injections(self, key: str, injections: Any) -> None:
        self._call_rpc("set_injections", key, injections)

    def get_injections(self, key: str) -> Any:
        return self._call_rpc("get_injections", key)

    def remove_injections(self, key: str) -> None:
        self._call_rpc("remove_injections", key)

    def set_additional_models(self, key: str, models: Any) -> None:
        """Register additional models by id; each must expose ``_instance_id``."""
        ids = [m._instance_id for m in models]
        self._call_rpc("set_additional_models", key, ids)

    def remove_additional_models(self, key: str) -> None:
        self._call_rpc("remove_additional_models", key)

    def get_nested_additional_models(self) -> Any:
        return self._call_rpc("get_nested_additional_models")

    def get_additional_models(self) -> List[ModelPatcherProxy]:
        """Wrap each remote additional-model id in a ``ModelPatcherProxy``."""
        ids = self._call_rpc("get_additional_models")
        return [
            ModelPatcherProxy(
                mid, self._registry, manage_lifecycle=not IS_CHILD_PROCESS
            )
            for mid in ids
        ]

    def model_patches_models(self) -> Any:
        return self._call_rpc("model_patches_models")

    @property
    def parent(self) -> Any:
        return self._call_rpc("get_parent")
# Inner-model attributes fetched verbatim through "get_inner_model_attr".
_INNER_MODEL_ATTR_NAMES = frozenset(
    {
        "model_config",
        "latent_format",
        "model_type",
        "current_weight_patches_uuid",
        "load_device",
        "device",
        "diffusion_model",
    }
)

# Inner-model methods exposed as callables: attribute name -> RPC method name.
# Each callable forwards its positional args as one tuple and kwargs as one dict.
_INNER_MODEL_METHOD_RPCS = {
    "extra_conds_shapes": "inner_model_extra_conds_shapes",
    "extra_conds": "inner_model_extra_conds",
    "memory_required": "inner_model_memory_required",
    "process_latent_in": "process_latent_in",
    "process_latent_out": "process_latent_out",
    "scale_latent_inpaint": "scale_latent_inpaint",
}


class _InnerModelProxy:
    """Stand-in for ``ModelPatcher.model`` backed by a ``ModelPatcherProxy``.

    Only an allow-list of attributes is served, each resolved through the
    owning proxy's ``_call_rpc``; anything else raises ``AttributeError``.
    """

    def __init__(self, parent: ModelPatcherProxy):
        self._parent = parent
        # Lazily fetched and cached on first ``model_sampling`` access.
        self._model_sampling = None

    def __getattr__(self, name: str) -> Any:
        # Private/dunder lookups must never trigger an RPC.
        if name.startswith("_"):
            raise AttributeError(name)

        if name in _INNER_MODEL_ATTR_NAMES:
            return self._parent._call_rpc("get_inner_model_attr", name)

        rpc_name = _INNER_MODEL_METHOD_RPCS.get(name)
        if rpc_name is not None:
            def _forward(*a, **k):
                return self._parent._call_rpc(rpc_name, a, k)

            return _forward

        if name == "model_sampling":
            cached = self._model_sampling
            if cached is None:
                cached = self._parent._call_rpc("get_model_object", "model_sampling")
                self._model_sampling = cached
            return cached

        if name == "apply_model":
            # The parent's apply_model handles device staging before the RPC.
            return self._parent.apply_model

        if name == "current_patcher":
            # Non-owning proxy for the same remote patcher id.
            owner = self._parent
            return ModelPatcherProxy(
                owner._instance_id,
                owner._registry,
                manage_lifecycle=False,
            )

        raise AttributeError(f"'{name}' not supported on isolated InnerModel")