diff --git a/comfy/hooks.py b/comfy/hooks.py index 1a76c7ba4..02111ce4d 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: import comfy.lora import comfy.model_management import comfy.patcher_extension +from comfy.cli_args import args +import uuid +import os from node_helpers import conditioning_set_values # ####################################################################################################### @@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum): HookedOnly = "hooked_only" +_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" + + class _HookRef: - pass + def __init__(self): + if _ISOLATION_HOOKREF_MODE: + self._pyisolate_id = str(uuid.uuid4()) + + def _ensure_pyisolate_id(self): + pyisolate_id = getattr(self, "_pyisolate_id", None) + if pyisolate_id is None: + pyisolate_id = str(uuid.uuid4()) + self._pyisolate_id = pyisolate_id + return pyisolate_id + + def __eq__(self, other): + if not _ISOLATION_HOOKREF_MODE: + return self is other + if not isinstance(other, _HookRef): + return False + return self._ensure_pyisolate_id() == other._ensure_pyisolate_id() + + def __hash__(self): + if not _ISOLATION_HOOKREF_MODE: + return id(self) + return hash(self._ensure_pyisolate_id()) + + def __str__(self): + if not _ISOLATION_HOOKREF_MODE: + return super().__str__() + return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}" def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -168,6 +200,8 @@ class WeightHook(Hook): key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) else: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if self.weights is None: + self.weights = {} weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: if target == EnumWeightTarget.Clip: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..c4299a3cd 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +import os from functools import partial from scipy import integrate @@ -12,8 +13,8 @@ from . import deis from . import sa_solver import comfy.model_patcher import comfy.model_sampling - import comfy.memory_management +from comfy.cli_args import args from comfy.utils import model_trange as trange def append_zero(x): @@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" + if isolation_active: + target_device = sigmas.device + if x.device != target_device: + x = x.to(target_device) + s_in = s_in.to(target_device) + for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. diff --git a/comfy/model_base.py b/comfy/model_base.py index 8f852e3c6..05bdaa242 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -76,7 +76,6 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 - IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -109,11 +108,17 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow - elif model_type == ModelType.IMG_TO_IMG_FLOW: - c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): - pass + def __reduce__(self): + """Ensure pickling yields a proxy instead of failing on local class.""" + try: + from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy + registry = ModelSamplingRegistry() + ms_id = registry.register(self) + return (ModelSamplingProxy, (ms_id,)) + except Exception as exc: + raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc return ModelSampling(model_config) @@ -974,10 +979,6 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) - guide_attention_entries = kwargs.get("guide_attention_entries", None) - if guide_attention_entries is not None: - out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) - return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1030,10 +1031,6 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) - guide_attention_entries = kwargs.get("guide_attention_entries", None) - if guide_attention_entries is not None: - out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) - return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): @@ -1477,12 +1474,6 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image -class WAN21_FlowRVS(WAN21): - def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): - model_config.unet_config["model_type"] = "t2v" - super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) - self.image_to_video = image_to_video - class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_management.py b/comfy/model_management.py index f73613f17..eed83f2fe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -378,7 +378,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]): @@ -570,7 +570,13 @@ class LoadedModel: self._patcher_finalizer.detach() def is_dead(self): - return self.real_model() is not None and self.model is None + # Model is dead if the weakref to model has been garbage collected + # This can happen with ModelPatcherProxy objects between isolated workflows + if self.model is None: + return True + if self.real_model is None: + return False + return self.real_model() is None def use_more_memory(extra_memory, loaded_models, device): @@ -616,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ unloaded_model = [] can_unload = [] unloaded_models = [] + isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] @@ -624,6 +631,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False + if can_unload and isolation_active: + try: + from pyisolate import flush_tensor_keeper # type: ignore[attr-defined] + except Exception: + flush_tensor_keeper = None + if callable(flush_tensor_keeper): + flushed = flush_tensor_keeper() + if flushed > 0: + logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed) + gc.collect() + for x in sorted(can_unload): i = x[-1] memory_to_free = 1e32 @@ -645,7 +663,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + unloaded = current_loaded_models.pop(i) + model_obj = unloaded.model + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() + unloaded_models.append(unloaded) if len(unloaded_model) > 0: soft_empty_cache() @@ -767,25 +791,28 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): - do_gc = False - reset_cast_buffers() - + dead_found = False for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.is_dead(): - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True + if current_loaded_models[i].is_dead(): + dead_found = True break - if do_gc: + if dead_found: + logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") gc.collect() soft_empty_cache() - for i in range(len(current_loaded_models)): + for i in range(len(current_loaded_models) - 1, -1, -1): cur = current_loaded_models[i] if cur.is_dead(): - logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) + logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") + leaked = current_loaded_models.pop(i) + model_obj = getattr(leaked, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() def archive_model_dtypes(model): @@ -802,6 +829,11 @@ def cleanup_models(): for i in to_delete: x = current_loaded_models.pop(i) + model_obj = getattr(x, "model", None) + if model_obj is not None: + cleanup = getattr(model_obj, "cleanup", None) + if callable(cleanup): + cleanup() del x def dtype_size(dtype): diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..b9aa8e175 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -11,12 +11,14 @@ from functools import partial import collections import math import logging +import os import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows import comfy.utils +from comfy.cli_args import args import scipy.stats import numpy @@ -213,6 +215,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc return executor.execute(model, conds, x_in, timestep, model_options) def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" out_conds = [] out_counts = [] # separate conds by matching hooks @@ -294,9 +297,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens patches = p.patches batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) + if isolation_active: + target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device + input_x = torch.cat(input_x).to(target_device) + else: + input_x = torch.cat(input_x) c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) + if isolation_active: + timestep_ = torch.cat([timestep] * batch_chunks).to(target_device) + mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult] + else: + timestep_ = torch.cat([timestep] * batch_chunks) transformer_options = model.current_patcher.apply_hooks(hooks=hooks) if 'transformer_options' in model_options: @@ -327,9 +338,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] + out_t = output[o] + mult_t = mult[o] + if isolation_active: + target_dev = out_conds[cond_index].device + if hasattr(out_t, "device") and out_t.device != target_dev: + out_t = out_t.to(target_dev) + if hasattr(mult_t, "device") and mult_t.device != target_dev: + mult_t = mult_t.to(target_dev) if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] + out_conds[cond_index] += out_t * mult_t + out_counts[cond_index] += mult_t else: out_c = out_conds[cond_index] out_cts = out_counts[cond_index] @@ -337,8 +356,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens for i in range(dims): out_c = out_c.narrow(i + 2, a[i + dims], a[i]) out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + out_c += out_t * mult_t + out_cts += mult_t for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -392,14 +411,31 @@ class KSamplerX0Inpaint: self.inner_model = model self.sigmas = sigmas def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): + isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" if denoise_mask is not None: + if isolation_active and denoise_mask.device != x.device: + denoise_mask = denoise_mask.to(x.device) if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask + if isolation_active: + latent_image = self.latent_image + if hasattr(latent_image, "device") and latent_image.device != x.device: + latent_image = latent_image.to(x.device) + scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image) + if hasattr(scaled, "device") and scaled.device != x.device: + scaled = scaled.to(x.device) + else: + scaled = self.inner_model.inner_model.scale_latent_inpaint( + x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image + ) + x = x * denoise_mask + scaled * latent_mask out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: - out = out * denoise_mask + self.latent_image * latent_mask + latent_image = self.latent_image + if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device: + latent_image = latent_image.to(out.device) + out = out * denoise_mask + latent_image * latent_mask return out def simple_scheduler(model_sampling, steps): diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c..f6d2063e9 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -92,7 +92,7 @@ if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" - else: + elif not args.use_process_isolation: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var diff --git a/execution.py b/execution.py index 75b021892..d0e457f11 100644 --- a/execution.py +++ b/execution.py @@ -1,7 +1,9 @@ import copy +import gc import heapq import inspect import logging +import os import sys import threading import time @@ -261,20 +263,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no state, just create clone - if is_class(obj): - type_obj = obj - obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(v3_data) - # otherwise, use class instance to populate/reuse some fields + # Check for isolated node - skip validation and class cloning + if hasattr(obj, "_pyisolate_extension"): + # Isolated Node: The stub is just a proxy; real validation happens in child process + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) + # Inject hidden inputs so they're available in the isolated child process + inputs.update(v3_data.get("hidden_inputs", {})) + f = getattr(obj, func) + # Standard V3 Node (Existing Logic) + else: - type_obj = type(obj) - type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) - f = make_locked_method_func(type_obj, func, class_clone) - # in case of dynamic inputs, restructure inputs to expected nested dict - if v3_data is not None: - inputs = _io.build_nested_inputs(inputs, v3_data) + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) + f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -536,6 +549,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + + # Keep isolation node execution deterministic by default, but allow + # opt-out for diagnostics. + isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes") + if args.use_process_isolation and isolation_sequential: + await await_completion() + return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs) + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -647,6 +668,22 @@ class PromptExecutor: self.status_messages = [] self.success = True + async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: + try: + from comfy.isolation import notify_execution_graph + await notify_execution_graph(class_types) + except Exception: + if fail_loud: + raise + logging.debug("][ EX:notify_execution_graph failed", exc_info=True) + + async def _flush_running_extensions_transport_state_safe(self) -> None: + try: + from comfy.isolation import flush_running_extensions_transport_state + await flush_running_extensions_transport_state() + except Exception: + logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) + def add_message(self, event, data: dict, broadcast: bool): data = { **data, @@ -688,6 +725,17 @@ class PromptExecutor: asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + # Update RPC event loops for all isolated extensions + # This is critical for serial workflow execution - each asyncio.run() creates + # a new event loop, and RPC instances must be updated to use it + try: + from comfy.isolation import update_rpc_event_loops + update_rpc_event_loops() + except ImportError: + pass # Isolation not available + except Exception as e: + logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") + set_preview_method(extra_data.get("preview_method")) nodes.interrupt_processing(False) @@ -701,6 +749,20 @@ class PromptExecutor: self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): + if args.use_process_isolation: + try: + # Boundary cleanup runs at the start of the next workflow in + # isolation mode, matching non-isolated "next prompt" timing. + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) + await self._flush_running_extensions_transport_state_safe() + comfy.model_management.unload_all_models() + comfy.model_management.cleanup_models_gc() + comfy.model_management.cleanup_models() + gc.collect() + comfy.model_management.soft_empty_cache() + except Exception: + logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True) + dynamic_prompt = DynamicPrompt(prompt) reset_progress_state(prompt_id, dynamic_prompt) add_progress_handler(WebUIProgressHandler(self.server)) @@ -727,6 +789,13 @@ class PromptExecutor: for node_id in list(execute_outputs): execution_list.add_node(node_id) + if args.use_process_isolation: + pending_class_types = set() + for node_id in execution_list.pendingNodes.keys(): + class_type = dynamic_prompt.get_node(node_id)["class_type"] + pending_class_types.add(class_type) + await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: @@ -757,6 +826,7 @@ class PromptExecutor: "outputs": ui_outputs, "meta": meta_outputs, } + comfy.model_management.cleanup_models_gc() self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() diff --git a/requirements.txt b/requirements.txt index b5b292980..632f067da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,5 @@ pydantic-settings~=2.0 PyOpenGL PyOpenGL-accelerate glfw + +pyisolate==0.9.1