fix(isolation-lifecycle): execution/model ejection parity + fenced sampler device handling

add pyisolate==0.9.1 to requirements.txt
This commit is contained in:
John Pollock 2026-02-27 13:07:23 -06:00
parent a1c3124821
commit 3c8ba051b6
8 changed files with 229 additions and 56 deletions

View File

@ -14,6 +14,9 @@ if TYPE_CHECKING:
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values
# #######################################################################################################
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
class _HookRef:
    """Reference token whose identity rules depend on ``_ISOLATION_HOOKREF_MODE``.

    With the flag off, instances compare and hash by object identity.
    With the flag on, each instance carries a UUID string in
    ``_pyisolate_id`` and equality/hash are derived from that string, so
    two distinct objects holding the same id compare equal.
    """

    def __init__(self):
        # The id is only allocated eagerly in isolation mode; otherwise the
        # instance stays attribute-free.
        if _ISOLATION_HOOKREF_MODE:
            self._pyisolate_id = str(uuid.uuid4())

    def _ensure_pyisolate_id(self):
        """Return the stored id, generating and storing one if it is missing."""
        existing = getattr(self, "_pyisolate_id", None)
        if existing is not None:
            return existing
        fresh = str(uuid.uuid4())
        self._pyisolate_id = fresh
        return fresh

    def __eq__(self, other):
        """Compare by id in isolation mode, by identity otherwise."""
        if _ISOLATION_HOOKREF_MODE:
            if isinstance(other, _HookRef):
                return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
            return False
        return self is other

    def __hash__(self):
        """Hash consistently with ``__eq__`` for the active mode."""
        if _ISOLATION_HOOKREF_MODE:
            return hash(self._ensure_pyisolate_id())
        return id(self)

    def __str__(self):
        """Return a ``PYISOLATE_HOOKREF:<id>`` tag in isolation mode, else the default text."""
        if _ISOLATION_HOOKREF_MODE:
            return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
        return super().__str__()
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@ -168,6 +200,8 @@ class WeightHook(Hook):
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if self.weights is None:
self.weights = {}
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Clip:

View File

@ -1,4 +1,5 @@
import math
import os
from functools import partial
from scipy import integrate
@ -12,8 +13,8 @@ from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
from comfy.cli_args import args
from comfy.utils import model_trange as trange
def append_zero(x):
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:
x = x.to(target_device)
s_in = s_in.to(target_device)
for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.

View File

@ -76,7 +76,6 @@ class ModelType(Enum):
FLUX = 8
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type):
@ -109,11 +108,17 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c):
    """Sampling class assembled from the selected ``s`` and ``c`` bases."""

    def __reduce__(self):
        """Pickle this object as a ``ModelSamplingProxy`` keyed by a registry id.

        The class is defined locally, so instead of pickling it directly the
        instance is registered and a ``(ModelSamplingProxy, (id,))`` reduce
        tuple is returned.

        Raises:
            RuntimeError: if importing the isolation proxy module or
                registering the instance fails; the cause is chained.
        """
        # NOTE(review): the copy module also consults __reduce__, so
        # copy/deepcopy of this object would yield a proxy too -- confirm
        # no caller relies on copying it.
        try:
            from comfy.isolation.model_sampling_proxy import ModelSamplingProxy, ModelSamplingRegistry
            ms_id = ModelSamplingRegistry().register(self)
        except Exception as exc:
            raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
        return (ModelSamplingProxy, (ms_id,))
return ModelSampling(model_config)
@ -974,10 +979,6 @@ class LTXV(BaseModel):
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -1030,10 +1031,6 @@ class LTXAV(BaseModel):
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1477,12 +1474,6 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@ -378,7 +378,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@ -570,7 +570,13 @@ class LoadedModel:
self._patcher_finalizer.detach()
def is_dead(self):
    """Report whether this loaded-model entry no longer refers to a live model.

    Returns:
        True when ``self.model`` is None, or when ``self.real_model`` is set
        and dereferencing it yields None. False when ``self.model`` is
        present and ``self.real_model`` is either unset or still resolves
        to an object.
    """
    # Model is dead if the weakref to model has been garbage collected.
    # This can happen with ModelPatcherProxy objects between isolated workflows.
    if self.model is None:
        return True
    # No real-model reference is set, so there is nothing to dereference.
    if self.real_model is None:
        return False
    # real_model is a callable reference; a None result means the target is gone.
    return self.real_model() is None
def use_more_memory(extra_memory, loaded_models, device):
@ -616,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
unloaded_model = []
can_unload = []
unloaded_models = []
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@ -624,6 +631,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
if can_unload and isolation_active:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
flush_tensor_keeper = None
if callable(flush_tensor_keeper):
flushed = flush_tensor_keeper()
if flushed > 0:
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
gc.collect()
for x in sorted(can_unload):
i = x[-1]
memory_to_free = 1e32
@ -645,7 +663,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
unloaded = current_loaded_models.pop(i)
model_obj = unloaded.model
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
unloaded_models.append(unloaded)
if len(unloaded_model) > 0:
soft_empty_cache()
@ -767,25 +791,28 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
    """Garbage-collect when a dead loaded-model entry exists, then evict leftovers.

    Steps:
      1. Reset the cast buffers.
      2. If any entry in ``current_loaded_models`` reports ``is_dead()``,
         log it, run a full ``gc.collect()`` and ``soft_empty_cache()``.
      3. Every entry that is still dead afterwards is logged as a leak,
         removed from ``current_loaded_models`` and, when its ``model``
         exposes a callable ``cleanup``, that hook is invoked.
    """
    reset_cast_buffers()

    def real_model_class_name(entry):
        # Resolve the real model through its reference when one is set.
        # A missing or already-collected target yields "NoneType", which is
        # what the log line used to hard-code.
        ref = getattr(entry, "real_model", None)
        target = ref() if callable(ref) else None
        return target.__class__.__name__

    first_dead = next((entry for entry in current_loaded_models if entry.is_dead()), None)
    if first_dead is None:
        return

    logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(real_model_class_name(first_dead)))
    gc.collect()
    soft_empty_cache()

    # Walk backwards so popping an entry never shifts an index still to visit.
    for i in range(len(current_loaded_models) - 1, -1, -1):
        cur = current_loaded_models[i]
        if not cur.is_dead():
            continue
        logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(real_model_class_name(cur)))
        leaked = current_loaded_models.pop(i)
        # getattr on a None model simply yields None, so no separate None check is needed.
        cleanup = getattr(getattr(leaked, "model", None), "cleanup", None)
        if callable(cleanup):
            cleanup()
def archive_model_dtypes(model):
@ -802,6 +829,11 @@ def cleanup_models():
for i in to_delete:
x = current_loaded_models.pop(i)
model_obj = getattr(x, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
del x
def dtype_size(dtype):

View File

@ -11,12 +11,14 @@ from functools import partial
import collections
import math
import logging
import os
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.utils
from comfy.cli_args import args
import scipy.stats
import numpy
@ -213,6 +215,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
out_conds = []
out_counts = []
# separate conds by matching hooks
@ -294,9 +297,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
if isolation_active:
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
input_x = torch.cat(input_x).to(target_device)
else:
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if isolation_active:
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
else:
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options:
@ -327,9 +338,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
out_t = output[o]
mult_t = mult[o]
if isolation_active:
target_dev = out_conds[cond_index].device
if hasattr(out_t, "device") and out_t.device != target_dev:
out_t = out_t.to(target_dev)
if hasattr(mult_t, "device") and mult_t.device != target_dev:
mult_t = mult_t.to(target_dev)
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
out_conds[cond_index] += out_t * mult_t
out_counts[cond_index] += mult_t
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
@ -337,8 +356,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
out_c += out_t * mult_t
out_cts += mult_t
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
@ -392,14 +411,31 @@ class KSamplerX0Inpaint:
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
    """Run the wrapped model, blending the stored latent back in outside the mask.

    When ``denoise_mask`` is given, the input is blended once before the
    model call (``x * mask + scale_latent_inpaint(...) * (1 - mask)``) and
    the output once after it (``out * mask + latent_image * (1 - mask)``).
    If ``model_options`` contains ``"denoise_mask_function"`` it is applied
    to the mask first.

    In isolation mode (``args.use_process_isolation`` or the environment
    variable ``PYISOLATE_ISOLATION_ACTIVE == "1"``) the mask, the stored
    latent image and the scaled latent are moved to the device of the
    tensor they are combined with; otherwise they are used as-is.

    Returns:
        The model output, blended with the latent image when a mask is given.
    """
    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"

    def on_device(value, device):
        # Device alignment is an isolation-mode concern only; outside of it
        # (or for objects without a .device) the value passes through untouched.
        if isolation_active and hasattr(value, "device") and value.device != device:
            return value.to(device)
        return value

    if denoise_mask is not None:
        denoise_mask = on_device(denoise_mask, x.device)
        if "denoise_mask_function" in model_options:
            denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
        latent_mask = 1. - denoise_mask
        scaled = self.inner_model.inner_model.scale_latent_inpaint(
            x=x, sigma=sigma, noise=self.noise, latent_image=on_device(self.latent_image, x.device)
        )
        # Blend exactly once: masked region keeps x, the rest takes the scaled latent.
        x = x * denoise_mask + on_device(scaled, x.device) * latent_mask
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
    if denoise_mask is not None:
        # Same single blend on the output side, against the unscaled latent image.
        out = out * denoise_mask + on_device(self.latent_image, out.device) * latent_mask
    return out
def simple_scheduler(model_sampling, steps):

View File

@ -92,7 +92,7 @@ if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
else:
elif not args.use_process_isolation:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

View File

@ -1,7 +1,9 @@
import copy
import gc
import heapq
import inspect
import logging
import os
import sys
import threading
import time
@ -261,20 +263,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
# Check for isolated node - skip validation and class cloning
if hasattr(obj, "_pyisolate_extension"):
# Isolated Node: The stub is just a proxy; real validation happens in child process
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# Inject hidden inputs so they're available in the isolated child process
inputs.update(v3_data.get("hidden_inputs", {}))
f = getattr(obj, func)
# Standard V3 Node (Existing Logic)
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@ -536,6 +549,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True)
unblock()
# Keep isolation node execution deterministic by default, but allow
# opt-out for diagnostics.
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
if args.use_process_isolation and isolation_sequential:
await await_completion()
return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
@ -647,6 +668,22 @@ class PromptExecutor:
self.status_messages = []
self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
    """Forward ``class_types`` to ``comfy.isolation.notify_execution_graph``.

    Any exception (including a failed import) is logged at debug level and
    suppressed, unless ``fail_loud`` is true, in which case it is re-raised.
    """
    try:
        from comfy.isolation import notify_execution_graph
        await notify_execution_graph(class_types)
    except Exception:
        if not fail_loud:
            logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
            return
        raise
async def _flush_running_extensions_transport_state_safe(self) -> None:
    """Await ``comfy.isolation.flush_running_extensions_transport_state`` on a best-effort basis.

    Any exception (including a failed import) is logged at debug level and
    never propagated to the caller.
    """
    try:
        from comfy.isolation import flush_running_extensions_transport_state as flush_transport_state
        await flush_transport_state()
    except Exception:
        logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool):
data = {
**data,
@ -688,6 +725,17 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# Update RPC event loops for all isolated extensions
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
@ -701,6 +749,20 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
@ -727,6 +789,13 @@ class PromptExecutor:
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
if args.use_process_isolation:
pending_class_types = set()
for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution()
if error is not None:
@ -757,6 +826,7 @@ class PromptExecutor:
"outputs": ui_outputs,
"meta": meta_outputs,
}
comfy.model_management.cleanup_models_gc()
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()

View File

@ -33,3 +33,5 @@ pydantic-settings~=2.0
PyOpenGL
PyOpenGL-accelerate
glfw
pyisolate==0.9.1