mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
fix(isolation-lifecycle): execution/model ejection parity + fenced sampler device handling
add pyisolate==0.9.1 to requirements.txt
This commit is contained in:
parent
a1c3124821
commit
3c8ba051b6
@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
|||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
from node_helpers import conditioning_set_values
|
from node_helpers import conditioning_set_values
|
||||||
|
|
||||||
# #######################################################################################################
|
# #######################################################################################################
|
||||||
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
|||||||
HookedOnly = "hooked_only"
|
HookedOnly = "hooked_only"
|
||||||
|
|
||||||
|
|
||||||
|
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||||
|
|
||||||
|
|
||||||
class _HookRef:
    """Reference object whose equality and hashing depend on the isolation flag.

    When the module-level ``_ISOLATION_HOOKREF_MODE`` is false, instances
    compare and hash by object identity and carry no extra state.  When it
    is true, each instance carries a ``_pyisolate_id`` UUID string and two
    instances are equal exactly when those strings are equal.
    """

    def __init__(self):
        # Only stamp an id on construction when isolation mode is active.
        if _ISOLATION_HOOKREF_MODE:
            self._pyisolate_id = str(uuid.uuid4())

    def _ensure_pyisolate_id(self):
        """Return ``_pyisolate_id``, generating and storing one if it is missing."""
        pyisolate_id = getattr(self, "_pyisolate_id", None)
        if pyisolate_id is None:
            # Covers instances that were created without going through
            # __init__ in isolation mode (no id was assigned yet).
            pyisolate_id = str(uuid.uuid4())
            self._pyisolate_id = pyisolate_id
        return pyisolate_id

    def __eq__(self, other):
        """Compare by identity, or by ``_pyisolate_id`` in isolation mode."""
        if not _ISOLATION_HOOKREF_MODE:
            return self is other
        if not isinstance(other, _HookRef):
            # Let Python try the reflected comparison and then fall back to
            # identity; the end result for ordinary objects is still False.
            return NotImplemented
        return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()

    def __hash__(self):
        """Hash consistently with ``__eq__`` for the active mode."""
        if not _ISOLATION_HOOKREF_MODE:
            return id(self)
        return hash(self._ensure_pyisolate_id())

    def __str__(self):
        """Return the default string, or ``PYISOLATE_HOOKREF:<id>`` in isolation mode."""
        if not _ISOLATION_HOOKREF_MODE:
            return super().__str__()
        return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
|
||||||
|
|
||||||
|
|
||||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
@ -168,6 +200,8 @@ class WeightHook(Hook):
|
|||||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||||
else:
|
else:
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
if self.weights is None:
|
||||||
|
self.weights = {}
|
||||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||||
else:
|
else:
|
||||||
if target == EnumWeightTarget.Clip:
|
if target == EnumWeightTarget.Clip:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
@ -12,8 +13,8 @@ from . import deis
|
|||||||
from . import sa_solver
|
from . import sa_solver
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
from comfy.cli_args import args
|
||||||
from comfy.utils import model_trange as trange
|
from comfy.utils import model_trange as trange
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||||
|
if isolation_active:
|
||||||
|
target_device = sigmas.device
|
||||||
|
if x.device != target_device:
|
||||||
|
x = x.to(target_device)
|
||||||
|
s_in = s_in.to(target_device)
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
if s_churn > 0:
|
if s_churn > 0:
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
|||||||
@ -76,7 +76,6 @@ class ModelType(Enum):
|
|||||||
FLUX = 8
|
FLUX = 8
|
||||||
IMG_TO_IMG = 9
|
IMG_TO_IMG = 9
|
||||||
FLOW_COSMOS = 10
|
FLOW_COSMOS = 10
|
||||||
IMG_TO_IMG_FLOW = 11
|
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -109,11 +108,17 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.FLOW_COSMOS:
|
elif model_type == ModelType.FLOW_COSMOS:
|
||||||
c = comfy.model_sampling.COSMOS_RFLOW
|
c = comfy.model_sampling.COSMOS_RFLOW
|
||||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
|
||||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
def __reduce__(self):
    """Pickle this object as a ``ModelSamplingProxy`` built from a registry id.

    Returns:
        The pair ``(ModelSamplingProxy, (ms_id,))`` where ``ms_id`` is the
        value returned by ``ModelSamplingRegistry().register(self)``.

    Raises:
        RuntimeError: If importing the proxy module or registering fails;
            the underlying exception is chained as the cause.
    """
    try:
        # The import sits inside the guarded region so that an import failure
        # surfaces through the same RuntimeError as a registration failure.
        from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
        proxy_args = (ModelSamplingRegistry().register(self),)
    except Exception as exc:
        raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
    return (ModelSamplingProxy, proxy_args)
|
||||||
|
|
||||||
return ModelSampling(model_config)
|
return ModelSampling(model_config)
|
||||||
|
|
||||||
@ -974,10 +979,6 @@ class LTXV(BaseModel):
|
|||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||||
|
|
||||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
|
||||||
if guide_attention_entries is not None:
|
|
||||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||||
@ -1030,10 +1031,6 @@ class LTXAV(BaseModel):
|
|||||||
if latent_shapes is not None:
|
if latent_shapes is not None:
|
||||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||||
|
|
||||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
|
||||||
if guide_attention_entries is not None:
|
|
||||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
@ -1477,12 +1474,6 @@ class WAN22(WAN21):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
class WAN21_FlowRVS(WAN21):
|
|
||||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
|
||||||
model_config.unet_config["model_type"] = "t2v"
|
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
|
||||||
self.image_to_video = image_to_video
|
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
|||||||
@ -350,7 +350,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||||
@ -378,7 +378,7 @@ try:
|
|||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if rocm_version >= (7, 0):
|
if rocm_version >= (7, 0):
|
||||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||||
@ -570,7 +570,13 @@ class LoadedModel:
|
|||||||
self._patcher_finalizer.detach()
|
self._patcher_finalizer.detach()
|
||||||
|
|
||||||
def is_dead(self):
    """Report whether this entry no longer refers to a live model.

    Returns:
        True when ``self.model`` is None, or when ``self.real_model`` is set
        and calling it yields None.  False otherwise, including when
        ``self.real_model`` is itself None.
    """
    # A missing ``model`` marks the entry dead outright; per the original
    # note this can happen with ModelPatcherProxy objects between isolated
    # workflows.
    if self.model is None:
        return True
    # NOTE(review): real_model is called with no arguments, so it looks like
    # a weakref-style callable — confirm against where it is assigned.
    real_ref = self.real_model
    return real_ref is not None and real_ref() is None
|
||||||
|
|
||||||
|
|
||||||
def use_more_memory(extra_memory, loaded_models, device):
|
def use_more_memory(extra_memory, loaded_models, device):
|
||||||
@ -616,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
unloaded_models = []
|
unloaded_models = []
|
||||||
|
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||||
|
|
||||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
@ -624,6 +631,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
|
if can_unload and isolation_active:
|
||||||
|
try:
|
||||||
|
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
flush_tensor_keeper = None
|
||||||
|
if callable(flush_tensor_keeper):
|
||||||
|
flushed = flush_tensor_keeper()
|
||||||
|
if flushed > 0:
|
||||||
|
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
@ -645,7 +663,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded = current_loaded_models.pop(i)
|
||||||
|
model_obj = unloaded.model
|
||||||
|
if model_obj is not None:
|
||||||
|
cleanup = getattr(model_obj, "cleanup", None)
|
||||||
|
if callable(cleanup):
|
||||||
|
cleanup()
|
||||||
|
unloaded_models.append(unloaded)
|
||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
@ -767,25 +791,28 @@ def loaded_models(only_currently_used=False):
|
|||||||
|
|
||||||
|
|
||||||
def cleanup_models_gc():
    """Garbage-collect when a dead loaded model is found, then evict leftovers.

    Resets the cast buffers.  If any entry of ``current_loaded_models``
    reports ``is_dead()``, runs ``gc.collect()`` and ``soft_empty_cache()``.
    Entries that are still dead afterwards are removed from the list, and
    when the removed entry's ``model`` exposes a callable ``cleanup`` it is
    invoked.
    """

    def _model_name(loaded):
        # Name the class behind ``real_model`` when it can still be resolved.
        # ``real_model`` may be None, or a callable that returns None once its
        # target is gone; both cases yield "NoneType" instead of raising.
        ref = getattr(loaded, "real_model", None)
        target = ref() if callable(ref) else None
        return type(target).__name__

    reset_cast_buffers()

    first_dead = next((m for m in current_loaded_models if m.is_dead()), None)
    if first_dead is None:
        # Nothing is dead, so there is nothing to collect or evict.
        return

    logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(_model_name(first_dead)))
    gc.collect()
    soft_empty_cache()

    # Walk backwards so that pop(i) does not shift indices still to be visited.
    for i in range(len(current_loaded_models) - 1, -1, -1):
        cur = current_loaded_models[i]
        if not cur.is_dead():
            continue
        logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(_model_name(cur)))
        leaked = current_loaded_models.pop(i)
        model_obj = getattr(leaked, "model", None)
        # getattr on None simply returns the default, so no separate None check.
        cleanup = getattr(model_obj, "cleanup", None)
        if callable(cleanup):
            cleanup()
|
||||||
|
|
||||||
|
|
||||||
def archive_model_dtypes(model):
|
def archive_model_dtypes(model):
|
||||||
@ -802,6 +829,11 @@ def cleanup_models():
|
|||||||
|
|
||||||
for i in to_delete:
|
for i in to_delete:
|
||||||
x = current_loaded_models.pop(i)
|
x = current_loaded_models.pop(i)
|
||||||
|
model_obj = getattr(x, "model", None)
|
||||||
|
if model_obj is not None:
|
||||||
|
cleanup = getattr(model_obj, "cleanup", None)
|
||||||
|
if callable(cleanup):
|
||||||
|
cleanup()
|
||||||
del x
|
del x
|
||||||
|
|
||||||
def dtype_size(dtype):
|
def dtype_size(dtype):
|
||||||
|
|||||||
@ -11,12 +11,14 @@ from functools import partial
|
|||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy.cli_args import args
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
@ -213,6 +215,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
|||||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
|
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
# separate conds by matching hooks
|
# separate conds by matching hooks
|
||||||
@ -294,9 +297,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
patches = p.patches
|
patches = p.patches
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
batch_chunks = len(cond_or_uncond)
|
||||||
input_x = torch.cat(input_x)
|
if isolation_active:
|
||||||
|
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||||
|
input_x = torch.cat(input_x).to(target_device)
|
||||||
|
else:
|
||||||
|
input_x = torch.cat(input_x)
|
||||||
c = cond_cat(c)
|
c = cond_cat(c)
|
||||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
if isolation_active:
|
||||||
|
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||||
|
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||||
|
else:
|
||||||
|
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||||
|
|
||||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||||
if 'transformer_options' in model_options:
|
if 'transformer_options' in model_options:
|
||||||
@ -327,9 +338,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
cond_index = cond_or_uncond[o]
|
cond_index = cond_or_uncond[o]
|
||||||
a = area[o]
|
a = area[o]
|
||||||
|
out_t = output[o]
|
||||||
|
mult_t = mult[o]
|
||||||
|
if isolation_active:
|
||||||
|
target_dev = out_conds[cond_index].device
|
||||||
|
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||||
|
out_t = out_t.to(target_dev)
|
||||||
|
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||||
|
mult_t = mult_t.to(target_dev)
|
||||||
if a is None:
|
if a is None:
|
||||||
out_conds[cond_index] += output[o] * mult[o]
|
out_conds[cond_index] += out_t * mult_t
|
||||||
out_counts[cond_index] += mult[o]
|
out_counts[cond_index] += mult_t
|
||||||
else:
|
else:
|
||||||
out_c = out_conds[cond_index]
|
out_c = out_conds[cond_index]
|
||||||
out_cts = out_counts[cond_index]
|
out_cts = out_counts[cond_index]
|
||||||
@ -337,8 +356,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
for i in range(dims):
|
for i in range(dims):
|
||||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||||
out_c += output[o] * mult[o]
|
out_c += out_t * mult_t
|
||||||
out_cts += mult[o]
|
out_cts += mult_t
|
||||||
|
|
||||||
for i in range(len(out_conds)):
|
for i in range(len(out_conds)):
|
||||||
out_conds[i] /= out_counts[i]
|
out_conds[i] /= out_counts[i]
|
||||||
@ -392,14 +411,31 @@ class KSamplerX0Inpaint:
|
|||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.sigmas = sigmas
|
self.sigmas = sigmas
|
||||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
    """Call the wrapped model on ``x``, blending in the stored latent by mask.

    With ``denoise_mask`` set, the input becomes
    ``x * denoise_mask + scaled * (1 - denoise_mask)`` where ``scaled`` is
    the result of ``scale_latent_inpaint``, and the returned value is
    ``out * denoise_mask + latent_image * (1 - denoise_mask)``.  With
    ``denoise_mask`` None the wrapped model's output is returned as is.

    When isolation is active (``args.use_process_isolation`` or the
    ``PYISOLATE_ISOLATION_ACTIVE`` environment variable equal to "1"), the
    mask, the stored latent image and the scaled latent are moved to the
    device of ``x`` (or of ``out`` for the final blend) when they differ.

    Args:
        x: Input passed to the wrapped model.
        sigma: Forwarded to ``scale_latent_inpaint`` and the wrapped model.
        denoise_mask: Blend mask, or None to skip blending.
        model_options: Options dict; an optional ``"denoise_mask_function"``
            entry is called to transform the mask.
        seed: Forwarded to the wrapped model.
    """

    def _match_device(tensor, like):
        # Move ``tensor`` next to ``like`` only if it exposes a device that differs.
        if hasattr(tensor, "device") and tensor.device != like.device:
            return tensor.to(like.device)
        return tensor

    isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"

    if denoise_mask is not None:
        if isolation_active and denoise_mask.device != x.device:
            denoise_mask = denoise_mask.to(x.device)

        mask_fn = model_options["denoise_mask_function"] if "denoise_mask_function" in model_options else None
        if mask_fn is not None:
            denoise_mask = mask_fn(sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
        latent_mask = 1. - denoise_mask

        # Outside isolation mode the stored latent is passed through untouched.
        latent_for_scale = _match_device(self.latent_image, x) if isolation_active else self.latent_image
        scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_for_scale)
        if isolation_active:
            scaled = _match_device(scaled, x)
        x = x * denoise_mask + scaled * latent_mask

    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)

    if denoise_mask is not None:
        latent_image = self.latent_image
        if isolation_active:
            latent_image = _match_device(latent_image, out)
        out = out * denoise_mask + latent_image * latent_mask
    return out
|
||||||
|
|
||||||
def simple_scheduler(model_sampling, steps):
|
def simple_scheduler(model_sampling, steps):
|
||||||
|
|||||||
@ -92,7 +92,7 @@ if args.cuda_malloc:
|
|||||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||||
if env_var is None:
|
if env_var is None:
|
||||||
env_var = "backend:cudaMallocAsync"
|
env_var = "backend:cudaMallocAsync"
|
||||||
else:
|
elif not args.use_process_isolation:
|
||||||
env_var += ",backend:cudaMallocAsync"
|
env_var += ",backend:cudaMallocAsync"
|
||||||
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||||
|
|||||||
96
execution.py
96
execution.py
@ -1,7 +1,9 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import gc
|
||||||
import heapq
|
import heapq
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -261,20 +263,31 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
# V3
|
# V3
|
||||||
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||||
# if is just a class, then assign no state, just create clone
|
# Check for isolated node - skip validation and class cloning
|
||||||
if is_class(obj):
|
if hasattr(obj, "_pyisolate_extension"):
|
||||||
type_obj = obj
|
# Isolated Node: The stub is just a proxy; real validation happens in child process
|
||||||
obj.VALIDATE_CLASS()
|
if v3_data is not None:
|
||||||
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||||
# otherwise, use class instance to populate/reuse some fields
|
# Inject hidden inputs so they're available in the isolated child process
|
||||||
|
inputs.update(v3_data.get("hidden_inputs", {}))
|
||||||
|
f = getattr(obj, func)
|
||||||
|
# Standard V3 Node (Existing Logic)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
type_obj = type(obj)
|
# if is just a class, then assign no resources or state, just create clone
|
||||||
type_obj.VALIDATE_CLASS()
|
if is_class(obj):
|
||||||
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
type_obj = obj
|
||||||
f = make_locked_method_func(type_obj, func, class_clone)
|
obj.VALIDATE_CLASS()
|
||||||
# in case of dynamic inputs, restructure inputs to expected nested dict
|
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
if v3_data is not None:
|
# otherwise, use class instance to populate/reuse some fields
|
||||||
inputs = _io.build_nested_inputs(inputs, v3_data)
|
else:
|
||||||
|
type_obj = type(obj)
|
||||||
|
type_obj.VALIDATE_CLASS()
|
||||||
|
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
|
||||||
|
f = make_locked_method_func(type_obj, func, class_clone)
|
||||||
|
# in case of dynamic inputs, restructure inputs to expected nested dict
|
||||||
|
if v3_data is not None:
|
||||||
|
inputs = _io.build_nested_inputs(inputs, v3_data)
|
||||||
# V1
|
# V1
|
||||||
else:
|
else:
|
||||||
f = getattr(obj, func)
|
f = getattr(obj, func)
|
||||||
@ -536,6 +549,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
unblock()
|
unblock()
|
||||||
|
|
||||||
|
# Keep isolation node execution deterministic by default, but allow
|
||||||
|
# opt-out for diagnostics.
|
||||||
|
isolation_sequential = os.environ.get("COMFY_ISOLATE_SEQUENTIAL", "1").lower() in ("1", "true", "yes")
|
||||||
|
if args.use_process_isolation and isolation_sequential:
|
||||||
|
await await_completion()
|
||||||
|
return await execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)
|
||||||
|
|
||||||
asyncio.create_task(await_completion())
|
asyncio.create_task(await_completion())
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
@ -647,6 +668,22 @@ class PromptExecutor:
|
|||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
|
|
||||||
|
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
    """Forward ``class_types`` to ``comfy.isolation.notify_execution_graph``.

    Args:
        class_types: Passed through unchanged to ``notify_execution_graph``.
        fail_loud: When true, any exception from the import or the awaited
            call is re-raised; otherwise it is logged at debug level with a
            traceback and swallowed.
    """
    try:
        from comfy.isolation import notify_execution_graph as _notify
        await _notify(class_types)
    except Exception:
        if not fail_loud:
            logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
            return
        raise
|
||||||
|
|
||||||
|
async def _flush_running_extensions_transport_state_safe(self) -> None:
    """Await ``comfy.isolation.flush_running_extensions_transport_state``, best effort.

    Any exception raised by the import or by the awaited call is logged at
    debug level with a traceback and then swallowed, so this never raises
    for those failures.
    """
    try:
        from comfy.isolation import flush_running_extensions_transport_state as _flush
        await _flush()
    except Exception:
        logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
|
||||||
|
|
||||||
def add_message(self, event, data: dict, broadcast: bool):
|
def add_message(self, event, data: dict, broadcast: bool):
|
||||||
data = {
|
data = {
|
||||||
**data,
|
**data,
|
||||||
@ -688,6 +725,17 @@ class PromptExecutor:
|
|||||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||||
|
|
||||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
|
# Update RPC event loops for all isolated extensions
|
||||||
|
# This is critical for serial workflow execution - each asyncio.run() creates
|
||||||
|
# a new event loop, and RPC instances must be updated to use it
|
||||||
|
try:
|
||||||
|
from comfy.isolation import update_rpc_event_loops
|
||||||
|
update_rpc_event_loops()
|
||||||
|
except ImportError:
|
||||||
|
pass # Isolation not available
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
|
||||||
|
|
||||||
set_preview_method(extra_data.get("preview_method"))
|
set_preview_method(extra_data.get("preview_method"))
|
||||||
|
|
||||||
nodes.interrupt_processing(False)
|
nodes.interrupt_processing(False)
|
||||||
@ -701,6 +749,20 @@ class PromptExecutor:
|
|||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
if args.use_process_isolation:
|
||||||
|
try:
|
||||||
|
# Boundary cleanup runs at the start of the next workflow in
|
||||||
|
# isolation mode, matching non-isolated "next prompt" timing.
|
||||||
|
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||||
|
await self._flush_running_extensions_transport_state_safe()
|
||||||
|
comfy.model_management.unload_all_models()
|
||||||
|
comfy.model_management.cleanup_models_gc()
|
||||||
|
comfy.model_management.cleanup_models()
|
||||||
|
gc.collect()
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
except Exception:
|
||||||
|
logging.debug("][ EX:isolation_boundary_cleanup_start failed", exc_info=True)
|
||||||
|
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
reset_progress_state(prompt_id, dynamic_prompt)
|
reset_progress_state(prompt_id, dynamic_prompt)
|
||||||
add_progress_handler(WebUIProgressHandler(self.server))
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
@ -727,6 +789,13 @@ class PromptExecutor:
|
|||||||
for node_id in list(execute_outputs):
|
for node_id in list(execute_outputs):
|
||||||
execution_list.add_node(node_id)
|
execution_list.add_node(node_id)
|
||||||
|
|
||||||
|
if args.use_process_isolation:
|
||||||
|
pending_class_types = set()
|
||||||
|
for node_id in execution_list.pendingNodes.keys():
|
||||||
|
class_type = dynamic_prompt.get_node(node_id)["class_type"]
|
||||||
|
pending_class_types.add(class_type)
|
||||||
|
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
|
||||||
|
|
||||||
while not execution_list.is_empty():
|
while not execution_list.is_empty():
|
||||||
node_id, error, ex = await execution_list.stage_node_execution()
|
node_id, error, ex = await execution_list.stage_node_execution()
|
||||||
if error is not None:
|
if error is not None:
|
||||||
@ -757,6 +826,7 @@ class PromptExecutor:
|
|||||||
"outputs": ui_outputs,
|
"outputs": ui_outputs,
|
||||||
"meta": meta_outputs,
|
"meta": meta_outputs,
|
||||||
}
|
}
|
||||||
|
comfy.model_management.cleanup_models_gc()
|
||||||
self.server.last_node_id = None
|
self.server.last_node_id = None
|
||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
|
|||||||
@ -33,3 +33,5 @@ pydantic-settings~=2.0
|
|||||||
PyOpenGL
|
PyOpenGL
|
||||||
PyOpenGL-accelerate
|
PyOpenGL-accelerate
|
||||||
glfw
|
glfw
|
||||||
|
|
||||||
|
pyisolate==0.9.1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user