fix: update isolation environment variable references to use PYISOLATE_CHILD
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
John Pollock 2026-03-04 17:21:15 -06:00
parent 8322f39219
commit 403f6210eb
6 changed files with 92 additions and 39 deletions

View File

@ -64,7 +64,7 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class _HookRef:

View File

@ -192,7 +192,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if isolation_active:
target_device = sigmas.device
if x.device != target_device:

View File

@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import os
import comfy.ldm.lightricks.av_model
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
@ -112,16 +113,20 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
from comfy.cli_args import args
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class ModelSampling(s, c):
def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
if isolation_runtime_enabled:
def __reduce__(self):
"""Ensure pickling yields a proxy instead of failing on local class."""
try:
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
registry = ModelSamplingRegistry()
ms_id = registry.register(self)
return (ModelSamplingProxy, (ms_id,))
except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
return ModelSampling(model_config)

View File

@ -483,6 +483,9 @@ except:
current_loaded_models = []
def _isolation_mode_enabled():
    """Tell whether isolation-specific code paths should be taken.

    Returns the value of ``args.use_process_isolation`` when it is truthy;
    otherwise returns whether the ``PYISOLATE_CHILD`` environment variable
    is exactly the string ``"1"``.
    """
    cli_flag = args.use_process_isolation
    if cli_flag:
        return cli_flag
    # Fall back to the environment marker when the CLI flag is not set.
    return os.environ.get("PYISOLATE_CHILD") == "1"
def module_size(module):
module_mem = 0
sd = module.state_dict()
@ -562,8 +565,9 @@ class LoadedModel:
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
if self.model_finalizer is not None:
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
@ -577,14 +581,15 @@ class LoadedModel:
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def dead_state(self):
    """Describe which of this entry's references have gone away.

    Returns:
        A ``(model_ref_gone, real_model_ref_gone)`` tuple of booleans.
        The first element is True when ``self.model`` is None. The second
        is True only when ``self.real_model`` is callable and calling it
        yields None; a non-callable ``real_model`` (such as None) is
        reported as False.
    """
    patcher_missing = self.model is None
    real_ref = self.real_model
    if callable(real_ref):
        # Call the stored callable once to see whether its target is still there.
        real_missing = real_ref() is None
    else:
        real_missing = False
    return patcher_missing, real_missing
def is_dead(self):
# Model is dead if the weakref to model has been garbage collected
# This can happen with ModelPatcherProxy objects between isolated workflows
if self.model is None:
return True
if self.real_model is None:
return False
return self.real_model() is None
model_ref_gone, real_model_ref_gone = self.dead_state()
return model_ref_gone or real_model_ref_gone
def use_more_memory(extra_memory, loaded_models, device):
@ -630,7 +635,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
unloaded_model = []
can_unload = []
unloaded_models = []
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
isolation_active = _isolation_mode_enabled()
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@ -735,7 +740,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for i in to_unload:
model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
if model_to_unload.model_finalizer is not None:
model_to_unload.model_finalizer.detach()
model_to_unload.model_finalizer = None
total_memory_required = {}
@ -799,21 +806,55 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
reset_cast_buffers()
if not _isolation_mode_enabled():
dead_found = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
dead_found = True
break
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
return
dead_found = False
has_real_model_leak = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
if model_ref_gone or real_model_ref_gone:
dead_found = True
break
if real_model_ref_gone:
has_real_model_leak = True
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
if has_real_model_leak:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
else:
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
model_ref_gone, real_model_ref_gone = cur.dead_state()
if model_ref_gone or real_model_ref_gone:
if real_model_ref_gone:
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
else:
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:

View File

@ -216,7 +216,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
return result
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
out_conds = []
out_counts = []
# separate conds by matching hooks
@ -413,7 +413,7 @@ class KSamplerX0Inpaint:
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if denoise_mask is not None:
if isolation_active and denoise_mask.device != x.device:
denoise_mask = denoise_mask.to(x.device)

View File

@ -681,6 +681,8 @@ class PromptExecutor:
self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types)
@ -690,6 +692,8 @@ class PromptExecutor:
logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
async def _flush_running_extensions_transport_state_safe(self) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import flush_running_extensions_transport_state
await flush_running_extensions_transport_state()
@ -703,6 +707,8 @@ class PromptExecutor:
timeout_ms: int = 120000,
marker: str = "EX:wait_model_patcher_idle",
) -> None:
if not args.use_process_isolation:
return
try:
from comfy.isolation import wait_for_model_patcher_quiescence
@ -755,16 +761,17 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# Update RPC event loops for all isolated extensions
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
if args.use_process_isolation:
# Update RPC event loops for all isolated extensions.
# This is critical for serial workflow execution - each asyncio.run() creates
# a new event loop, and RPC instances must be updated to use it.
try:
from comfy.isolation import update_rpc_event_loops
update_rpc_event_loops()
except ImportError:
pass # Isolation not available
except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method"))