mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
- harden isolation ModelPatcher proxy/registry behavior for DynamicVRAM-backed patchers - improve serializer/adapter boundaries (device/dtype/model refs) to reduce pre-inference lockups - add structured ISO registry/modelsampling telemetry and explicit RPC timeout surfacing - preserve isolation-first lifecycle handling and boundary cleanup sequencing - validate isolated workflows: most targeted runs now complete under --use-sage-attention --use-process-isolation --disable-cuda-malloc Known issue (reproducible): - isolation_99_full_iso_stack still times out at SamplerCustom_ISO path - failure is explicit RPC timeout: ModelPatcherProxy.process_latent_in(instance_id=model_0, timeout_ms=120000) - this indicates the remaining stall is on process_latent_in RPC path, not generic startup/manager fetch
381 lines
13 KiB
Python
381 lines
13 KiB
Python
# pylint: disable=import-outside-toplevel
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Any
|
|
|
|
from comfy.isolation.proxies.base import (
|
|
BaseProxy,
|
|
BaseRegistry,
|
|
detach_if_grad,
|
|
get_thread_loop,
|
|
run_coro_in_new_loop,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _describe_value(obj: Any) -> str:
|
|
try:
|
|
import torch
|
|
except Exception:
|
|
torch = None
|
|
try:
|
|
if torch is not None and isinstance(obj, torch.Tensor):
|
|
return (
|
|
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
|
|
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
|
|
)
|
|
except Exception:
|
|
pass
|
|
return "%s(id=%s)" % (type(obj).__name__, id(obj))
|
|
|
|
|
|
def _prefer_device(*tensors: Any) -> Any:
|
|
try:
|
|
import torch
|
|
except Exception:
|
|
return None
|
|
for t in tensors:
|
|
if isinstance(t, torch.Tensor) and t.is_cuda:
|
|
return t.device
|
|
for t in tensors:
|
|
if isinstance(t, torch.Tensor):
|
|
return t.device
|
|
return None
|
|
|
|
|
|
def _to_device(obj: Any, device: Any) -> Any:
|
|
try:
|
|
import torch
|
|
except Exception:
|
|
return obj
|
|
if device is None:
|
|
return obj
|
|
if isinstance(obj, torch.Tensor):
|
|
if obj.device != device:
|
|
return obj.to(device)
|
|
return obj
|
|
if isinstance(obj, (list, tuple)):
|
|
converted = [_to_device(x, device) for x in obj]
|
|
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
|
if isinstance(obj, dict):
|
|
return {k: _to_device(v, device) for k, v in obj.items()}
|
|
return obj
|
|
|
|
|
|
def _to_cpu_for_rpc(obj: Any) -> Any:
|
|
try:
|
|
import torch
|
|
except Exception:
|
|
return obj
|
|
if isinstance(obj, torch.Tensor):
|
|
t = obj.detach() if obj.requires_grad else obj
|
|
if t.is_cuda:
|
|
return t.to("cpu")
|
|
return t
|
|
if isinstance(obj, (list, tuple)):
|
|
converted = [_to_cpu_for_rpc(x) for x in obj]
|
|
return type(obj)(converted) if isinstance(obj, tuple) else converted
|
|
if isinstance(obj, dict):
|
|
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
|
|
return obj
|
|
|
|
|
|
class ModelSamplingRegistry(BaseRegistry[Any]):
|
|
_type_prefix = "modelsampling"
|
|
|
|
async def calculate_input(self, instance_id: str, sigma: Any, noise: Any) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.calculate_input(sigma, noise))
|
|
|
|
async def calculate_denoised(
|
|
self, instance_id: str, sigma: Any, model_output: Any, model_input: Any
|
|
) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(
|
|
sampling.calculate_denoised(sigma, model_output, model_input)
|
|
)
|
|
|
|
async def noise_scaling(
|
|
self,
|
|
instance_id: str,
|
|
sigma: Any,
|
|
noise: Any,
|
|
latent_image: Any,
|
|
max_denoise: bool = False,
|
|
) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(
|
|
sampling.noise_scaling(sigma, noise, latent_image, max_denoise=max_denoise)
|
|
)
|
|
|
|
async def inverse_noise_scaling(
|
|
self, instance_id: str, sigma: Any, latent: Any
|
|
) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.inverse_noise_scaling(sigma, latent))
|
|
|
|
async def timestep(self, instance_id: str, sigma: Any) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return sampling.timestep(sigma)
|
|
|
|
async def sigma(self, instance_id: str, timestep: Any) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return sampling.sigma(timestep)
|
|
|
|
async def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return sampling.percent_to_sigma(percent)
|
|
|
|
async def get_sigma_min(self, instance_id: str) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.sigma_min)
|
|
|
|
async def get_sigma_max(self, instance_id: str) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.sigma_max)
|
|
|
|
async def get_sigma_data(self, instance_id: str) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.sigma_data)
|
|
|
|
async def get_sigmas(self, instance_id: str) -> Any:
|
|
sampling = self._get_instance(instance_id)
|
|
return detach_if_grad(sampling.sigmas)
|
|
|
|
async def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
|
sampling = self._get_instance(instance_id)
|
|
sampling.set_sigmas(sigmas)
|
|
|
|
|
|
class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
|
|
_registry_class = ModelSamplingRegistry
|
|
__module__ = "comfy.isolation.model_sampling_proxy"
|
|
|
|
def _get_rpc(self) -> Any:
|
|
if self._rpc_caller is None:
|
|
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
|
|
|
rpc = get_child_rpc_instance()
|
|
if rpc is not None:
|
|
self._rpc_caller = rpc.create_caller(
|
|
ModelSamplingRegistry, ModelSamplingRegistry.get_remote_id()
|
|
)
|
|
else:
|
|
registry = ModelSamplingRegistry()
|
|
|
|
class _LocalCaller:
|
|
def calculate_input(
|
|
self, instance_id: str, sigma: Any, noise: Any
|
|
) -> Any:
|
|
return registry.calculate_input(instance_id, sigma, noise)
|
|
|
|
def calculate_denoised(
|
|
self,
|
|
instance_id: str,
|
|
sigma: Any,
|
|
model_output: Any,
|
|
model_input: Any,
|
|
) -> Any:
|
|
return registry.calculate_denoised(
|
|
instance_id, sigma, model_output, model_input
|
|
)
|
|
|
|
def noise_scaling(
|
|
self,
|
|
instance_id: str,
|
|
sigma: Any,
|
|
noise: Any,
|
|
latent_image: Any,
|
|
max_denoise: bool = False,
|
|
) -> Any:
|
|
return registry.noise_scaling(
|
|
instance_id, sigma, noise, latent_image, max_denoise
|
|
)
|
|
|
|
def inverse_noise_scaling(
|
|
self, instance_id: str, sigma: Any, latent: Any
|
|
) -> Any:
|
|
return registry.inverse_noise_scaling(
|
|
instance_id, sigma, latent
|
|
)
|
|
|
|
def timestep(self, instance_id: str, sigma: Any) -> Any:
|
|
return registry.timestep(instance_id, sigma)
|
|
|
|
def sigma(self, instance_id: str, timestep: Any) -> Any:
|
|
return registry.sigma(instance_id, timestep)
|
|
|
|
def percent_to_sigma(self, instance_id: str, percent: float) -> Any:
|
|
return registry.percent_to_sigma(instance_id, percent)
|
|
|
|
def get_sigma_min(self, instance_id: str) -> Any:
|
|
return registry.get_sigma_min(instance_id)
|
|
|
|
def get_sigma_max(self, instance_id: str) -> Any:
|
|
return registry.get_sigma_max(instance_id)
|
|
|
|
def get_sigma_data(self, instance_id: str) -> Any:
|
|
return registry.get_sigma_data(instance_id)
|
|
|
|
def get_sigmas(self, instance_id: str) -> Any:
|
|
return registry.get_sigmas(instance_id)
|
|
|
|
def set_sigmas(self, instance_id: str, sigmas: Any) -> None:
|
|
return registry.set_sigmas(instance_id, sigmas)
|
|
|
|
self._rpc_caller = _LocalCaller()
|
|
return self._rpc_caller
|
|
|
|
def _call(self, method_name: str, *args: Any) -> Any:
|
|
print(
|
|
"ISO:modelsampling_call_enter method=%s instance_id=%s pid=%s"
|
|
% (method_name, self._instance_id, os.getpid()),
|
|
flush=True,
|
|
)
|
|
rpc = self._get_rpc()
|
|
method = getattr(rpc, method_name)
|
|
print(
|
|
"ISO:modelsampling_call_before_dispatch method=%s instance_id=%s pid=%s"
|
|
% (method_name, self._instance_id, os.getpid()),
|
|
flush=True,
|
|
)
|
|
result = method(self._instance_id, *args)
|
|
timeout_ms = self._rpc_timeout_ms()
|
|
start_epoch = time.time()
|
|
start_perf = time.perf_counter()
|
|
thread_id = threading.get_ident()
|
|
call_id = "%s:%s:%s:%.6f" % (
|
|
self._instance_id,
|
|
method_name,
|
|
thread_id,
|
|
start_perf,
|
|
)
|
|
logger.debug(
|
|
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
|
|
method_name,
|
|
self._instance_id,
|
|
call_id,
|
|
start_epoch,
|
|
thread_id,
|
|
timeout_ms,
|
|
)
|
|
if asyncio.iscoroutine(result):
|
|
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
|
|
try:
|
|
asyncio.get_running_loop()
|
|
out = run_coro_in_new_loop(result)
|
|
except RuntimeError:
|
|
loop = get_thread_loop()
|
|
out = loop.run_until_complete(result)
|
|
else:
|
|
out = result
|
|
print(
|
|
"ISO:modelsampling_call_after_dispatch method=%s instance_id=%s pid=%s"
|
|
% (method_name, self._instance_id, os.getpid()),
|
|
flush=True,
|
|
)
|
|
logger.debug(
|
|
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
|
|
method_name,
|
|
self._instance_id,
|
|
call_id,
|
|
_describe_value(out),
|
|
)
|
|
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
|
|
logger.debug(
|
|
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
|
|
method_name,
|
|
self._instance_id,
|
|
call_id,
|
|
elapsed_ms,
|
|
thread_id,
|
|
)
|
|
logger.debug(
|
|
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
|
|
method_name,
|
|
self._instance_id,
|
|
call_id,
|
|
)
|
|
print(
|
|
"ISO:modelsampling_call_return method=%s instance_id=%s pid=%s"
|
|
% (method_name, self._instance_id, os.getpid()),
|
|
flush=True,
|
|
)
|
|
return out
|
|
|
|
@staticmethod
|
|
def _rpc_timeout_ms() -> int:
|
|
raw = os.environ.get(
|
|
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
|
|
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
|
|
)
|
|
try:
|
|
timeout_ms = int(raw)
|
|
except ValueError:
|
|
timeout_ms = 30000
|
|
return max(1, timeout_ms)
|
|
|
|
@property
|
|
def sigma_min(self) -> Any:
|
|
return self._call("get_sigma_min")
|
|
|
|
@property
|
|
def sigma_max(self) -> Any:
|
|
return self._call("get_sigma_max")
|
|
|
|
@property
|
|
def sigma_data(self) -> Any:
|
|
return self._call("get_sigma_data")
|
|
|
|
@property
|
|
def sigmas(self) -> Any:
|
|
return self._call("get_sigmas")
|
|
|
|
def calculate_input(self, sigma: Any, noise: Any) -> Any:
|
|
return self._call("calculate_input", sigma, noise)
|
|
|
|
def calculate_denoised(
|
|
self, sigma: Any, model_output: Any, model_input: Any
|
|
) -> Any:
|
|
return self._call("calculate_denoised", sigma, model_output, model_input)
|
|
|
|
def noise_scaling(
|
|
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
|
|
) -> Any:
|
|
preferred_device = _prefer_device(noise, latent_image)
|
|
out = self._call(
|
|
"noise_scaling",
|
|
_to_cpu_for_rpc(sigma),
|
|
_to_cpu_for_rpc(noise),
|
|
_to_cpu_for_rpc(latent_image),
|
|
max_denoise,
|
|
)
|
|
return _to_device(out, preferred_device)
|
|
|
|
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
|
|
preferred_device = _prefer_device(latent)
|
|
out = self._call(
|
|
"inverse_noise_scaling",
|
|
_to_cpu_for_rpc(sigma),
|
|
_to_cpu_for_rpc(latent),
|
|
)
|
|
return _to_device(out, preferred_device)
|
|
|
|
def timestep(self, sigma: Any) -> Any:
|
|
return self._call("timestep", sigma)
|
|
|
|
def sigma(self, timestep: Any) -> Any:
|
|
return self._call("sigma", timestep)
|
|
|
|
def percent_to_sigma(self, percent: float) -> Any:
|
|
return self._call("percent_to_sigma", percent)
|
|
|
|
def set_sigmas(self, sigmas: Any) -> None:
|
|
return self._call("set_sigmas", sigmas)
|