fix(isolation-lifecycle): execution/model ejection parity + fenced sampler device handling

add pyisolate==0.9.1 to requirements.txt
This commit is contained in:
John Pollock 2026-02-27 13:07:23 -06:00
parent a1c3124821
commit 3c8ba051b6
8 changed files with 229 additions and 56 deletions

View File

@ -14,6 +14,9 @@ if TYPE_CHECKING:
import comfy.lora import comfy.lora
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
from comfy.cli_args import args
import uuid
import os
from node_helpers import conditioning_set_values from node_helpers import conditioning_set_values
# ####################################################################################################### # #######################################################################################################
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only" HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
class _HookRef:
    """Identity token used to pair hooks with their registrations.

    Normally a plain sentinel compared by object identity.  When process
    isolation is active (``_ISOLATION_HOOKREF_MODE``), object identity does
    not survive serialization across the process boundary, so each instance
    carries a UUID and equality/hashing/printing are based on that id.
    """

    def __init__(self):
        # Eagerly mint the id in isolation mode so it exists in the owning
        # process before the ref is ever serialized to a child process.
        if _ISOLATION_HOOKREF_MODE:
            self._pyisolate_id = str(uuid.uuid4())

    def _ensure_pyisolate_id(self):
        """Return this ref's stable UUID, creating it on first use.

        Lazy creation covers instances built before isolation mode was
        enabled, or deserialized without the attribute set.
        """
        pyisolate_id = getattr(self, "_pyisolate_id", None)
        if pyisolate_id is None:
            pyisolate_id = str(uuid.uuid4())
            self._pyisolate_id = pyisolate_id
        return pyisolate_id

    def __eq__(self, other):
        if not _ISOLATION_HOOKREF_MODE:
            return self is other
        if not isinstance(other, _HookRef):
            # Give the other operand's reflected __eq__ a chance instead of
            # unconditionally declaring inequality.
            return NotImplemented
        return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()

    def __hash__(self):
        if not _ISOLATION_HOOKREF_MODE:
            return id(self)
        return hash(self._ensure_pyisolate_id())

    def __str__(self):
        if not _ISOLATION_HOOKREF_MODE:
            return super().__str__()
        return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
@ -168,6 +200,8 @@ class WeightHook(Hook):
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else: else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if self.weights is None:
self.weights = {}
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else: else:
if target == EnumWeightTarget.Clip: if target == EnumWeightTarget.Clip:

View File

@ -1,4 +1,5 @@
import math import math
import os
from functools import partial from functools import partial
from scipy import integrate from scipy import integrate
@ -12,8 +13,8 @@ from . import deis
from . import sa_solver from . import sa_solver
import comfy.model_patcher import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
import comfy.memory_management import comfy.memory_management
from comfy.cli_args import args
from comfy.utils import model_trange as trange from comfy.utils import model_trange as trange
def append_zero(x): def append_zero(x):
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:
x = x.to(target_device)
s_in = s_in.to(target_device)
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
if s_churn > 0: if s_churn > 0:
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.

View File

@ -76,7 +76,6 @@ class ModelType(Enum):
FLUX = 8 FLUX = 8
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -109,11 +108,17 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS: elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c): class ModelSampling(s, c):
pass def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
return ModelSampling(model_config) return ModelSampling(model_config)
@ -974,10 +979,6 @@ class LTXV(BaseModel):
if keyframe_idxs is not None: if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -1030,10 +1031,6 @@ class LTXAV(BaseModel):
if latent_shapes is not None: if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1477,12 +1474,6 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
try: try:
if is_amd(): if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@ -378,7 +378,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0): if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]): if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@ -570,7 +570,13 @@ class LoadedModel:
self._patcher_finalizer.detach() self._patcher_finalizer.detach()
def is_dead(self):
    """Report whether this LoadedModel no longer holds a usable model.

    A model is dead once the weakref to the real model has been garbage
    collected.  This can happen with ModelPatcherProxy objects between
    isolated workflows, so both ``model`` and ``real_model`` are allowed
    to be ``None`` here.
    """
    if self.model is None:
        # The patcher itself is gone; nothing left to manage.
        return True
    if self.real_model is None:
        # No weakref was ever taken, so there is nothing to have died.
        return False
    # Dead exactly when the weakref's target has been collected.
    return self.real_model() is None
def use_more_memory(extra_memory, loaded_models, device): def use_more_memory(extra_memory, loaded_models, device):
@ -616,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
@ -624,6 +631,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
if can_unload and isolation_active:
try:
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
except Exception:
flush_tensor_keeper = None
if callable(flush_tensor_keeper):
flushed = flush_tensor_keeper()
if flushed > 0:
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
gc.collect()
for x in sorted(can_unload): for x in sorted(can_unload):
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
@ -645,7 +663,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
current_loaded_models[i].model.partially_unload_ram(ram_to_free) current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded = current_loaded_models.pop(i)
model_obj = unloaded.model
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
unloaded_models.append(unloaded)
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
@ -767,25 +791,28 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
    """Detect and reap dead LoadedModel entries left behind by leaks.

    Pass 1: if any tracked entry is dead, force a full garbage collect —
    circular references in model code can otherwise keep dead models
    alive.  Pass 2: any entry still dead after the collect is a genuine
    leak; it is removed from tracking and its model's ``cleanup()`` hook
    (if any) is invoked so proxies can release cross-process resources.
    """
    reset_cast_buffers()

    # One dead entry is enough to justify a full gc pass.
    dead_found = any(entry.is_dead() for entry in current_loaded_models)

    if dead_found:
        logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
        gc.collect()
        soft_empty_cache()

        # Iterate in reverse so pop(i) does not shift the indices of
        # entries we have not visited yet.
        for i in range(len(current_loaded_models) - 1, -1, -1):
            cur = current_loaded_models[i]
            if cur.is_dead():
                logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
                leaked = current_loaded_models.pop(i)
                model_obj = getattr(leaked, "model", None)
                if model_obj is not None:
                    cleanup = getattr(model_obj, "cleanup", None)
                    if callable(cleanup):
                        cleanup()
def archive_model_dtypes(model): def archive_model_dtypes(model):
@ -802,6 +829,11 @@ def cleanup_models():
for i in to_delete: for i in to_delete:
x = current_loaded_models.pop(i) x = current_loaded_models.pop(i)
model_obj = getattr(x, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
del x del x
def dtype_size(dtype): def dtype_size(dtype):

View File

@ -11,12 +11,14 @@ from functools import partial
import collections import collections
import math import math
import logging import logging
import os
import comfy.sampler_helpers import comfy.sampler_helpers
import comfy.model_patcher import comfy.model_patcher
import comfy.patcher_extension import comfy.patcher_extension
import comfy.hooks import comfy.hooks
import comfy.context_windows import comfy.context_windows
import comfy.utils import comfy.utils
from comfy.cli_args import args
import scipy.stats import scipy.stats
import numpy import numpy
@ -213,6 +215,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
return executor.execute(model, conds, x_in, timestep, model_options) return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
out_conds = [] out_conds = []
out_counts = [] out_counts = []
# separate conds by matching hooks # separate conds by matching hooks
@ -294,9 +297,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
patches = p.patches patches = p.patches
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) if isolation_active:
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
input_x = torch.cat(input_x).to(target_device)
else:
input_x = torch.cat(input_x)
c = cond_cat(c) c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks) if isolation_active:
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
else:
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = model.current_patcher.apply_hooks(hooks=hooks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
if 'transformer_options' in model_options: if 'transformer_options' in model_options:
@ -327,9 +338,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for o in range(batch_chunks): for o in range(batch_chunks):
cond_index = cond_or_uncond[o] cond_index = cond_or_uncond[o]
a = area[o] a = area[o]
out_t = output[o]
mult_t = mult[o]
if isolation_active:
target_dev = out_conds[cond_index].device
if hasattr(out_t, "device") and out_t.device != target_dev:
out_t = out_t.to(target_dev)
if hasattr(mult_t, "device") and mult_t.device != target_dev:
mult_t = mult_t.to(target_dev)
if a is None: if a is None:
out_conds[cond_index] += output[o] * mult[o] out_conds[cond_index] += out_t * mult_t
out_counts[cond_index] += mult[o] out_counts[cond_index] += mult_t
else: else:
out_c = out_conds[cond_index] out_c = out_conds[cond_index]
out_cts = out_counts[cond_index] out_cts = out_counts[cond_index]
@ -337,8 +356,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for i in range(dims): for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i]) out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o] out_c += out_t * mult_t
out_cts += mult[o] out_cts += mult_t
for i in range(len(out_conds)): for i in range(len(out_conds)):
out_conds[i] /= out_counts[i] out_conds[i] /= out_counts[i]
@ -392,14 +411,31 @@ class KSamplerX0Inpaint:
self.inner_model = model self.inner_model = model
self.sigmas = sigmas self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
    """Run the inner model with inpainting masking applied.

    Blends ``x`` with the scaled latent image according to ``denoise_mask``
    before the model call, then re-composites the unmasked region of the
    latent image into the output afterwards.  Under process isolation,
    tensors may arrive on a different device than ``x``/``out`` (they can
    cross a process boundary), so they are moved onto the local device
    before any arithmetic.  ``model_options`` is never mutated here, so the
    shared default dict is safe.
    """
    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"

    def _align(tensor, device):
        # Only move tensor-likes, and only when isolation may have left
        # them on another device; everything else passes through untouched.
        if isolation_active and hasattr(tensor, "device") and tensor.device != device:
            return tensor.to(device)
        return tensor

    if denoise_mask is not None:
        denoise_mask = _align(denoise_mask, x.device)
        if "denoise_mask_function" in model_options:
            denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
        latent_mask = 1. - denoise_mask
        latent_image = _align(self.latent_image, x.device)
        scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
        scaled = _align(scaled, x.device)
        x = x * denoise_mask + scaled * latent_mask
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
    if denoise_mask is not None:
        latent_image = _align(self.latent_image, out.device)
        out = out * denoise_mask + latent_image * latent_mask
    return out
def simple_scheduler(model_sampling, steps): def simple_scheduler(model_sampling, steps):

View File

@ -92,7 +92,7 @@ if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None: if env_var is None:
env_var = "backend:cudaMallocAsync" env_var = "backend:cudaMallocAsync"
else: elif not args.use_process_isolation:
env_var += ",backend:cudaMallocAsync" env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var

View File

@ -1,7 +1,9 @@
import copy import copy
import gc
import heapq import heapq
import inspect import inspect
import logging import logging
import os
import sys import sys
import threading import threading
import time import time
@ -261,20 +263,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index) pre_execute_cb(index)
# V3 # V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no state, just create clone # Check for isolated node - skip validation and class cloning
if is_class(obj): if hasattr(obj, "_pyisolate_extension"):
type_obj = obj # Isolated Node: The stub is just a proxy; real validation happens in child process
obj.VALIDATE_CLASS() if v3_data is not None:
class_clone = obj.PREPARE_CLASS_CLONE(v3_data) inputs = _io.build_nested_inputs(inputs, v3_data)
# otherwise, use class instance to populate/reuse some fields # Inject hidden inputs so they're available in the isolated child process
inputs.update(v3_data.get("hidden_inputs", {}))
f = getattr(obj, func)
# Standard V3 Node (Existing Logic)
else: else:
type_obj = type(obj) # if is just a class, then assign no resources or state, just create clone
type_obj.VALIDATE_CLASS() if is_class(obj):
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) type_obj = obj
f = make_locked_method_func(type_obj, func, class_clone) obj.VALIDATE_CLASS()
# in case of dynamic inputs, restructure inputs to expected nested dict class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
if v3_data is not None: # otherwise, use class instance to populate/reuse some fields
inputs = _io.build_nested_inputs(inputs, v3_data) else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1 # V1
else: else:
f = getattr(obj, func) f = getattr(obj, func)
@ -536,6 +549,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
tasks = [x for x in output_data if isinstance(x, asyncio.Task)] tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
unblock() unblock()
# Keep isolation node execution deterministic by default, but allow
# opt-out for diagnostics.
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
if args.use_process_isolation and isolation_sequential:
await await_completion()
return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
@ -647,6 +668,22 @@ class PromptExecutor:
self.status_messages = [] self.status_messages = []
self.success = True self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
try:
from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types)
except Exception:
if fail_loud:
raise
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
async def _flush_running_extensions_transport_state_safe(self) -> None:
try:
from comfy.isolation import flush_running_extensions_transport_state
await flush_running_extensions_transport_state()
except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool): def add_message(self, event, data: dict, broadcast: bool):
data = { data = {
**data, **data,
@ -688,6 +725,17 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# Update RPC event loops for all isolated extensions
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method")) set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
@ -701,6 +749,20 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): with torch.inference_mode():
if args.use_process_isolation:
try:
# Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc()
comfy.model_management.cleanup_models()
gc.collect()
comfy.model_management.soft_empty_cache()
except Exception:
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server)) add_progress_handler(WebUIProgressHandler(self.server))
@ -727,6 +789,13 @@ class PromptExecutor:
for node_id in list(execute_outputs): for node_id in list(execute_outputs):
execution_list.add_node(node_id) execution_list.add_node(node_id)
if args.use_process_isolation:
pending_class_types = set()
for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty(): while not execution_list.is_empty():
node_id, error, ex = await execution_list.stage_node_execution() node_id, error, ex = await execution_list.stage_node_execution()
if error is not None: if error is not None:
@ -757,6 +826,7 @@ class PromptExecutor:
"outputs": ui_outputs, "outputs": ui_outputs,
"meta": meta_outputs, "meta": meta_outputs,
} }
comfy.model_management.cleanup_models_gc()
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()

View File

@ -33,3 +33,5 @@ pydantic-settings~=2.0
PyOpenGL PyOpenGL
PyOpenGL-accelerate PyOpenGL-accelerate
glfw glfw
pyisolate==0.9.1