diff --git a/comfy/hooks.py b/comfy/hooks.py
index 02111ce4d..7a5f69ca7 100644
--- a/comfy/hooks.py
+++ b/comfy/hooks.py
@@ -64,7 +64,7 @@ class EnumHookScope(enum.Enum):
     HookedOnly = "hooked_only"
 
 
-_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
+_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
 
 
 class _HookRef:
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index c4299a3cd..4ed4a9250 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -192,7 +192,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
-    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
+    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
     if isolation_active:
         target_device = sigmas.device
         if x.device != target_device:
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3783dc88c..fcf0bd93e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
 import comfy.ldm.hunyuan3dv2_1.hunyuandit
 import torch
 import logging
+import os
 import comfy.ldm.lightricks.av_model
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from comfy.ldm.cascade.stage_c import StageC
@@ -112,16 +113,20 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.IMG_TO_IMG_FLOW:
         c = comfy.model_sampling.IMG_TO_IMG_FLOW
 
+    from comfy.cli_args import args
+    isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
+
     class ModelSampling(s, c):
-        def __reduce__(self):
-            """Ensure pickling yields a proxy instead of failing on local class."""
-            try:
-                from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
-                registry = ModelSamplingRegistry()
-                ms_id = registry.register(self)
-                return (ModelSamplingProxy, (ms_id,))
-            except Exception as exc:
-                raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
+        if isolation_runtime_enabled:
+            def __reduce__(self):
+                """Ensure pickling yields a proxy instead of failing on local class."""
+                try:
+                    from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
+                    registry = ModelSamplingRegistry()
+                    ms_id = registry.register(self)
+                    return (ModelSamplingProxy, (ms_id,))
+                except Exception as exc:
+                    raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
 
     return ModelSampling(model_config)
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index f472f7fe9..89e2d6c17 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -483,6 +483,9 @@ except:
 
 current_loaded_models = []
 
+def _isolation_mode_enabled():
+    return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -562,8 +565,9 @@ class LoadedModel:
                 if freed >= memory_to_free:
                     return False
         self.model.detach(unpatch_weights)
-        self.model_finalizer.detach()
-        self.model_finalizer = None
+        if self.model_finalizer is not None:
+            self.model_finalizer.detach()
+            self.model_finalizer = None
         self.real_model = None
         return True
 
@@ -577,14 +581,15 @@
         if self._patcher_finalizer is not None:
             self._patcher_finalizer.detach()
 
+    def dead_state(self):
+        model_ref_gone = self.model is None
+        real_model_ref = self.real_model
+        real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
+        return model_ref_gone, real_model_ref_gone
+
     def is_dead(self):
-        # Model is dead if the weakref to model has been garbage collected
-        # This can happen with ModelPatcherProxy objects between isolated workflows
-        if self.model is None:
-            return True
-        if self.real_model is None:
-            return False
-        return self.real_model() is None
+        model_ref_gone, real_model_ref_gone = self.dead_state()
+        return model_ref_gone or real_model_ref_gone
 
 
 def use_more_memory(extra_memory, loaded_models, device):
@@ -630,7 +635,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     unloaded_model = []
     can_unload = []
     unloaded_models = []
-    isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
+    isolation_active = _isolation_mode_enabled()
 
     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
@@ -735,7 +740,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     for i in to_unload:
         model_to_unload = current_loaded_models.pop(i)
         model_to_unload.model.detach(unpatch_all=False)
-        model_to_unload.model_finalizer.detach()
+        if model_to_unload.model_finalizer is not None:
+            model_to_unload.model_finalizer.detach()
+            model_to_unload.model_finalizer = None
 
     total_memory_required = {}
 
@@ -799,21 +806,55 @@ def loaded_models(only_currently_used=False):
 
 def cleanup_models_gc():
     reset_cast_buffers()
+    if not _isolation_mode_enabled():
+        dead_found = False
+        for i in range(len(current_loaded_models)):
+            if current_loaded_models[i].is_dead():
+                dead_found = True
+                break
+
+        if dead_found:
+            logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
+            gc.collect()
+            soft_empty_cache()
+
+        for i in range(len(current_loaded_models) - 1, -1, -1):
+            cur = current_loaded_models[i]
+            if cur.is_dead():
+                logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
+                leaked = current_loaded_models.pop(i)
+                model_obj = getattr(leaked, "model", None)
+                if model_obj is not None:
+                    cleanup = getattr(model_obj, "cleanup", None)
+                    if callable(cleanup):
+                        cleanup()
+        return
+
     dead_found = False
+    has_real_model_leak = False
     for i in range(len(current_loaded_models)):
-        if current_loaded_models[i].is_dead():
+        model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
+        if model_ref_gone or real_model_ref_gone:
             dead_found = True
-            break
+            if real_model_ref_gone:
+                has_real_model_leak = True
 
     if dead_found:
-        logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
+        if has_real_model_leak:
+            logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
+        else:
+            logging.debug("Cleaning stale loaded-model entries with released patcher references.")
         gc.collect()
         soft_empty_cache()
 
     for i in range(len(current_loaded_models) - 1, -1, -1):
         cur = current_loaded_models[i]
-        if cur.is_dead():
-            logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
+        model_ref_gone, real_model_ref_gone = cur.dead_state()
+        if model_ref_gone or real_model_ref_gone:
+            if real_model_ref_gone:
+                logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
+            else:
+                logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
             leaked = current_loaded_models.pop(i)
             model_obj = getattr(leaked, "model", None)
             if model_obj is not None:
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1b30f0064..b79ac575b 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -216,7 +216,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
     return result
 
 def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
-    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
+    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
     out_conds = []
     out_counts = []
     # separate conds by matching hooks
@@ -413,7 +413,7 @@ class KSamplerX0Inpaint:
         self.inner_model = model
         self.sigmas = sigmas
     def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
-        isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
+        isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
         if denoise_mask is not None:
             if isolation_active and denoise_mask.device != x.device:
                 denoise_mask = denoise_mask.to(x.device)
diff --git a/execution.py b/execution.py
index d04498a0d..5e64a6229 100644
--- a/execution.py
+++ b/execution.py
@@ -681,6 +681,8 @@ class PromptExecutor:
         self.success = True
 
     async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
+        if not args.use_process_isolation:
+            return
         try:
             from comfy.isolation import notify_execution_graph
             await notify_execution_graph(class_types)
@@ -690,6 +692,8 @@
             logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
 
     async def _flush_running_extensions_transport_state_safe(self) -> None:
+        if not args.use_process_isolation:
+            return
         try:
             from comfy.isolation import flush_running_extensions_transport_state
             await flush_running_extensions_transport_state()
@@ -703,6 +707,8 @@
         timeout_ms: int = 120000,
         marker: str = "EX:wait_model_patcher_idle",
     ) -> None:
+        if not args.use_process_isolation:
+            return
         try:
             from comfy.isolation import wait_for_model_patcher_quiescence
 
@@ -755,16 +761,17 @@
             asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
 
     async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
-        # Update RPC event loops for all isolated extensions
-        # This is critical for serial workflow execution - each asyncio.run() creates
-        # a new event loop, and RPC instances must be updated to use it
-        try:
-            from comfy.isolation import update_rpc_event_loops
-            update_rpc_event_loops()
-        except ImportError:
-            pass # Isolation not available
-        except Exception as e:
-            logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
+        if args.use_process_isolation:
+            # Update RPC event loops for all isolated extensions.
+            # This is critical for serial workflow execution - each asyncio.run() creates
+            # a new event loop, and RPC instances must be updated to use it.
+            try:
+                from comfy.isolation import update_rpc_event_loops
+                update_rpc_event_loops()
+            except ImportError:
+                pass # Isolation not available
+            except Exception as e:
+                logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
 
         set_preview_method(extra_data.get("preview_method"))
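
Note: the gate `args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"` remains open-coded in comfy/hooks.py, comfy/k_diffusion/sampling.py, comfy/model_base.py, and comfy/samplers.py, while comfy/model_management.py now routes it through `_isolation_mode_enabled()`. Below is a minimal sketch of what a shared helper could look like if those call sites were later unified; the module path `comfy/isolation_env.py` and the name `isolation_enabled` are hypothetical and not part of this patch.

```python
# comfy/isolation_env.py -- hypothetical shared helper, not part of this patch.
import os

from comfy.cli_args import args


def isolation_enabled() -> bool:
    """Report whether process isolation is in effect for this process.

    True either when process isolation was requested on the command line
    (args.use_process_isolation) or when this process appears to be a
    pyisolate child, signalled by PYISOLATE_CHILD being set to "1" -- the
    same convention each call site in the patch checks inline.
    """
    return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
```

With such a helper, call sites would read `isolation_active = isolation_enabled()`, so a future rename of the environment variable (this patch already renames `PYISOLATE_ISOLATION_ACTIVE` to `PYISOLATE_CHILD`) would touch one line instead of six.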