mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
feat(isolation): sandbox policy and runtime fencing
This commit is contained in:
parent
fbb6be5624
commit
c9ebc6aa57
@ -1,2 +1,2 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||||
pause
|
pause
|
||||||
|
|||||||
@ -14,6 +14,9 @@ if TYPE_CHECKING:
|
|||||||
import comfy.lora
|
import comfy.lora
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
from node_helpers import conditioning_set_values
|
from node_helpers import conditioning_set_values
|
||||||
|
|
||||||
# #######################################################################################################
|
# #######################################################################################################
|
||||||
@ -61,8 +64,37 @@ class EnumHookScope(enum.Enum):
|
|||||||
HookedOnly = "hooked_only"
|
HookedOnly = "hooked_only"
|
||||||
|
|
||||||
|
|
||||||
|
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
|
||||||
|
|
||||||
class _HookRef:
    """Identity token for a hook.

    When the module-level ``_ISOLATION_HOOKREF_MODE`` flag is falsy the object
    behaves like a plain ``object``: equality is identity and the hash is
    ``id(self)``.  When the flag is truthy, equality, hashing and ``str()``
    are all derived from a UUID string stored in ``_pyisolate_id``.
    NOTE(review): presumably this lets a reference keep its identity after
    being copied into another process - confirm against the isolation layer.
    """

    def __init__(self):
        # The id is only created eagerly in isolation mode; otherwise the
        # attribute stays absent until something asks for it.
        if _ISOLATION_HOOKREF_MODE:
            self._pyisolate_id = str(uuid.uuid4())

    def _ensure_pyisolate_id(self):
        """Return the stored id, creating and storing one if it is missing or None."""
        existing = getattr(self, "_pyisolate_id", None)
        if existing is not None:
            return existing
        fresh = str(uuid.uuid4())
        self._pyisolate_id = fresh
        return fresh

    def __eq__(self, other):
        """Compare by id in isolation mode, by object identity otherwise."""
        if _ISOLATION_HOOKREF_MODE:
            if isinstance(other, _HookRef):
                return self._ensure_pyisolate_id() == other._ensure_pyisolate_id()
            return False
        return self is other

    def __hash__(self):
        """Hash consistently with ``__eq__`` for the active mode."""
        return hash(self._ensure_pyisolate_id()) if _ISOLATION_HOOKREF_MODE else id(self)

    def __str__(self):
        """Return ``PYISOLATE_HOOKREF:<id>`` in isolation mode, the default text otherwise."""
        if _ISOLATION_HOOKREF_MODE:
            return f"PYISOLATE_HOOKREF:{self._ensure_pyisolate_id()}"
        return super().__str__()
|
||||||
|
|
||||||
|
|
||||||
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
|
||||||
@ -168,6 +200,8 @@ class WeightHook(Hook):
|
|||||||
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
|
||||||
else:
|
else:
|
||||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
if self.weights is None:
|
||||||
|
self.weights = {}
|
||||||
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
|
||||||
else:
|
else:
|
||||||
if target == EnumWeightTarget.Clip:
|
if target == EnumWeightTarget.Clip:
|
||||||
|
|||||||
180
comfy/isolation/host_policy.py
Normal file
180
comfy/isolation/host_policy.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# pylint: disable=logging-fstring-interpolation
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from pathlib import PurePosixPath
|
||||||
|
from typing import Dict, List, TypedDict
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tomllib
|
||||||
|
except ImportError:
|
||||||
|
import tomli as tomllib # type: ignore[no-redef]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Name of the environment variable that, when set to a non-empty value,
# replaces the default ``<comfy_root>/pyproject.toml`` policy file location.
HOST_POLICY_PATH_ENV = "COMFY_HOST_POLICY_PATH"
# Accepted ``sandbox_mode`` values; the configured string is stripped and
# lower-cased before it is compared against this set.
VALID_SANDBOX_MODES = frozenset({"required", "disabled"})
# Normalized paths that are dropped from ``writable_paths`` when it is loaded.
FORBIDDEN_WRITABLE_PATHS = frozenset({"/tmp"})
|
||||||
|
|
||||||
|
|
||||||
|
class HostSecurityPolicy(TypedDict):
    """Shape of the policy dictionary returned by ``load_host_policy``."""

    # One of the strings in VALID_SANDBOX_MODES.
    sandbox_mode: str
    # Value of the ``allow_network`` setting from the host table.
    allow_network: bool
    # POSIX-style path strings, with forbidden entries already removed.
    writable_paths: List[str]
    # Path strings taken from the ``readonly_paths`` setting.
    readonly_paths: List[str]
    # De-duplicated absolute path strings from ``sealed_worker_ro_import_paths``.
    sealed_worker_ro_import_paths: List[str]
    # Maps an entry name to a string value; entries read from a whitelist
    # file are stored with the value "*".
    whitelist: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
# Baseline policy used when no configuration is found or parsing fails.
# Callers obtain a copy through ``_default_policy`` rather than mutating this
# mapping in place.
DEFAULT_POLICY: HostSecurityPolicy = {
    "sandbox_mode": "required",
    "allow_network": False,
    "writable_paths": ["/dev/shm"],
    "readonly_paths": [],
    "sealed_worker_ro_import_paths": [],
    "whitelist": {},
}
|
||||||
|
|
||||||
|
|
||||||
|
def _default_policy() -> HostSecurityPolicy:
    """Return an independent, mutable copy of ``DEFAULT_POLICY``.

    A deep copy is used instead of re-listing every key by hand, so a key
    added to ``DEFAULT_POLICY`` later is carried over automatically and no
    nested list or dict is shared with the module-level default.

    Returns:
        A new policy dictionary equal to ``DEFAULT_POLICY``.
    """
    import copy  # local import: not part of this module's import block

    return copy.deepcopy(DEFAULT_POLICY)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_writable_paths(paths: list[object]) -> list[str]:
    """Convert configured writable paths to POSIX-style strings and drop forbidden ones.

    Args:
        paths: Entries from the ``writable_paths`` setting.  A lone string is
            treated as a one-element list.

    Returns:
        The normalized path strings, in input order, without any entry that
        resolves lexically to a member of ``FORBIDDEN_WRITABLE_PATHS``.
    """
    import posixpath  # local import: lexical POSIX rules regardless of host OS

    # Iterating a bare string would yield single characters, turning "/data"
    # into entries such as "/" - far broader than what was configured.
    if isinstance(paths, str):
        paths = [paths]

    normalized_paths: list[str] = []
    for raw_path in paths:
        # Host-policy paths are contract-style POSIX paths; keep representation
        # stable across Windows/Linux so tests and config behavior stay consistent.
        normalized_path = str(PurePosixPath(str(raw_path).replace("\\", "/")))

        # Only the comparison uses the collapsed form, so spellings such as
        # "//tmp" or "/var/../tmp" cannot slip past the exclusion; the stored
        # value keeps the representation the original code produced.
        comparison_key = posixpath.normpath(normalized_path)
        if comparison_key.startswith("//"):
            # normpath (like PurePosixPath) preserves exactly two leading slashes.
            comparison_key = "/" + comparison_key.lstrip("/")

        if normalized_path in FORBIDDEN_WRITABLE_PATHS or comparison_key in FORBIDDEN_WRITABLE_PATHS:
            logger.warning("Ignoring forbidden writable path %r from host policy.", normalized_path)
            continue
        normalized_paths.append(normalized_path)
    return normalized_paths
|
||||||
|
|
||||||
|
|
||||||
|
def _load_whitelist_file(file_path: Path, config_path: Path) -> Dict[str, str]:
    """Read whitelist entries from a text file, one entry per line.

    Blank lines and lines starting with ``#`` are ignored.  Every remaining
    line (stripped of surrounding whitespace) becomes a key mapped to ``"*"``.

    Args:
        file_path: Location of the whitelist file; a relative path is
            resolved against the directory containing ``config_path``.
        config_path: Path of the policy file that referenced the whitelist.

    Returns:
        The entries found, or an empty dict when the path is not a regular file.
    """
    if not file_path.is_absolute():
        file_path = config_path.parent / file_path
    # is_file() rather than exists(): a directory at this path would make
    # read_text() raise instead of being skipped like a missing file.
    if not file_path.is_file():
        logger.warning("whitelist_file %s not found, skipping.", file_path)
        return {}

    entries: Dict[str, str] = {}
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    for line in file_path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        entries[line] = "*"
    logger.debug("Loaded %d whitelist entries from %s", len(entries), file_path)
    return entries
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_sealed_worker_ro_import_paths(raw_paths: object) -> list[str]:
    """Validate and normalize the ``sealed_worker_ro_import_paths`` setting.

    Args:
        raw_paths: The raw configured value; must be a list of non-empty
            strings, each an absolute POSIX path or a drive-letter path.

    Returns:
        POSIX-style path strings with backslashes converted to ``/``,
        de-duplicated while keeping first-seen order.

    Raises:
        ValueError: If the value is not a list, an entry is not a non-empty
            string, or an entry is not an absolute path.
    """
    if not isinstance(raw_paths, list):
        raise ValueError(
            "tool.comfy.host.sealed_worker_ro_import_paths must be a list of absolute paths."
        )

    normalized_paths: list[str] = []
    seen: set[str] = set()
    for raw_path in raw_paths:
        if not isinstance(raw_path, str) or not raw_path.strip():
            raise ValueError(
                "tool.comfy.host.sealed_worker_ro_import_paths entries must be non-empty strings."
            )
        normalized_path = str(PurePosixPath(raw_path.replace("\\", "/")))
        # Accept both POSIX absolute paths (/home/...) and Windows drive-letter paths (D:/...).
        # The drive must be an ASCII letter; otherwise values such as "1:/x"
        # would be accepted as absolute.
        has_drive_prefix = (
            len(normalized_path) >= 3
            and normalized_path[0].isascii()
            and normalized_path[0].isalpha()
            and normalized_path[1:3] == ":/"
        )
        if not (normalized_path.startswith("/") or has_drive_prefix):
            raise ValueError(
                "tool.comfy.host.sealed_worker_ro_import_paths entries must be absolute paths."
            )
        if normalized_path not in seen:
            seen.add(normalized_path)
            normalized_paths.append(normalized_path)

    return normalized_paths
|
||||||
|
|
||||||
|
|
||||||
|
def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
    """Build the host security policy from the ``[tool.comfy.host]`` TOML table.

    The policy file is ``<comfy_root>/pyproject.toml`` unless the environment
    variable named by ``HOST_POLICY_PATH_ENV`` points elsewhere.  Defaults are
    returned when the file is missing, cannot be parsed, or does not contain
    a usable host table.

    Side effect: once a host table has been processed, the resulting sandbox
    mode is exported as ``PYISOLATE_SANDBOX_MODE`` in ``os.environ``.  The
    early-return paths do not set it.

    Args:
        comfy_root: Directory that holds the default ``pyproject.toml``.

    Returns:
        The effective policy dictionary.

    Raises:
        ValueError: If ``sealed_worker_ro_import_paths`` is present but invalid.
    """
    config_override = os.environ.get(HOST_POLICY_PATH_ENV)
    config_path = Path(config_override) if config_override else comfy_root / "pyproject.toml"
    policy = _default_policy()

    if not config_path.exists():
        logger.debug("Host policy file missing at %s, using defaults.", config_path)
        return policy

    try:
        with config_path.open("rb") as f:
            data = tomllib.load(f)
    except Exception:
        # Deliberately broad: an unreadable or malformed policy file must fall
        # back to the default policy rather than abort the caller.
        logger.warning(
            "Failed to parse host policy from %s, using defaults.",
            config_path,
            exc_info=True,
        )
        return policy

    # Descend tool -> comfy -> host one level at a time.  A non-table value at
    # any level (e.g. ``tool = "x"``) would make a chained ``.get`` raise
    # AttributeError, so it is treated like a missing section instead.
    tool_config: object = data
    for section in ("tool", "comfy", "host"):
        tool_config = tool_config.get(section, {}) if isinstance(tool_config, dict) else None
    if not isinstance(tool_config, dict):
        logger.debug("No [tool.comfy.host] section found, using defaults.")
        return policy

    sandbox_mode = tool_config.get("sandbox_mode")
    if isinstance(sandbox_mode, str):
        normalized_sandbox_mode = sandbox_mode.strip().lower()
        if normalized_sandbox_mode in VALID_SANDBOX_MODES:
            policy["sandbox_mode"] = normalized_sandbox_mode
        else:
            logger.warning(
                "Invalid host sandbox_mode %r in %s, using default %r.",
                sandbox_mode,
                config_path,
                DEFAULT_POLICY["sandbox_mode"],
            )

    if "allow_network" in tool_config:
        allow_network = tool_config["allow_network"]
        # Only a real TOML boolean is honoured: bool("false") is True, so
        # coercing arbitrary values could silently enable networking.
        if isinstance(allow_network, bool):
            policy["allow_network"] = allow_network
        else:
            logger.warning(
                "Invalid host allow_network %r in %s, using default %r.",
                allow_network,
                config_path,
                DEFAULT_POLICY["allow_network"],
            )

    if "writable_paths" in tool_config:
        writable_paths = tool_config["writable_paths"]
        # A lone string would otherwise be iterated character by character.
        if isinstance(writable_paths, str):
            writable_paths = [writable_paths]
        policy["writable_paths"] = _normalize_writable_paths(writable_paths)

    if "readonly_paths" in tool_config:
        readonly_paths = tool_config["readonly_paths"]
        if isinstance(readonly_paths, str):
            readonly_paths = [readonly_paths]
        policy["readonly_paths"] = [str(p) for p in readonly_paths]

    if "sealed_worker_ro_import_paths" in tool_config:
        policy["sealed_worker_ro_import_paths"] = _normalize_sealed_worker_ro_import_paths(
            tool_config["sealed_worker_ro_import_paths"]
        )

    # File entries are merged first so that inline ``whitelist`` values
    # override them on key collisions.
    whitelist_file = tool_config.get("whitelist_file")
    if isinstance(whitelist_file, str):
        policy["whitelist"].update(_load_whitelist_file(Path(whitelist_file), config_path))

    whitelist_raw = tool_config.get("whitelist")
    if isinstance(whitelist_raw, dict):
        policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})

    os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]

    logger.debug(
        "Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
        len(policy["whitelist"]),
        policy["sandbox_mode"],
        policy["allow_network"],
    )
    return policy
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["HostSecurityPolicy", "load_host_policy", "DEFAULT_POLICY"]
|
||||||
@ -1,4 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
@ -12,8 +13,8 @@ from . import deis
|
|||||||
from . import sa_solver
|
from . import sa_solver
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
from comfy.cli_args import args
|
||||||
from comfy.utils import model_trange as trange
|
from comfy.utils import model_trange as trange
|
||||||
|
|
||||||
def append_zero(x):
|
def append_zero(x):
|
||||||
@ -191,6 +192,13 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
|||||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
if isolation_active:
|
||||||
|
target_device = sigmas.device
|
||||||
|
if x.device != target_device:
|
||||||
|
x = x.to(target_device)
|
||||||
|
s_in = s_in.to(target_device)
|
||||||
|
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
if s_churn > 0:
|
if s_churn > 0:
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
|
|||||||
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import comfy.ldm.lightricks.av_model
|
import comfy.ldm.lightricks.av_model
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
@ -120,8 +121,20 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.V_PREDICTION_DDPM:
|
elif model_type == ModelType.V_PREDICTION_DDPM:
|
||||||
c = comfy.model_sampling.V_PREDICTION_DDPM
|
c = comfy.model_sampling.V_PREDICTION_DDPM
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
if isolation_runtime_enabled:
|
||||||
|
def __reduce__(self):
|
||||||
|
"""Ensure pickling yields a proxy instead of failing on local class."""
|
||||||
|
try:
|
||||||
|
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
|
||||||
|
registry = ModelSamplingRegistry()
|
||||||
|
ms_id = registry.register(self)
|
||||||
|
return (ModelSamplingProxy, (ms_id,))
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
|
||||||
|
|
||||||
return ModelSampling(model_config)
|
return ModelSampling(model_config)
|
||||||
|
|
||||||
|
|||||||
@ -498,6 +498,9 @@ except:
|
|||||||
|
|
||||||
current_loaded_models = []
|
current_loaded_models = []
|
||||||
|
|
||||||
|
def _isolation_mode_enabled():
    """Report whether process isolation is active for this interpreter.

    Returns the ``args.use_process_isolation`` value when it is truthy;
    otherwise returns whether the ``PYISOLATE_CHILD`` environment variable
    equals ``"1"``.
    """
    requested = args.use_process_isolation
    if requested:
        return requested
    return os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
|
|
||||||
def module_size(module):
|
def module_size(module):
|
||||||
module_mem = 0
|
module_mem = 0
|
||||||
sd = module.state_dict()
|
sd = module.state_dict()
|
||||||
@ -604,8 +607,9 @@ class LoadedModel:
|
|||||||
if freed >= memory_to_free:
|
if freed >= memory_to_free:
|
||||||
return False
|
return False
|
||||||
self.model.detach(unpatch_weights)
|
self.model.detach(unpatch_weights)
|
||||||
self.model_finalizer.detach()
|
if self.model_finalizer is not None:
|
||||||
self.model_finalizer = None
|
self.model_finalizer.detach()
|
||||||
|
self.model_finalizer = None
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -619,8 +623,15 @@ class LoadedModel:
|
|||||||
if self._patcher_finalizer is not None:
|
if self._patcher_finalizer is not None:
|
||||||
self._patcher_finalizer.detach()
|
self._patcher_finalizer.detach()
|
||||||
|
|
||||||
|
def dead_state(self):
    """Return ``(model_ref_gone, real_model_ref_gone)`` for this entry.

    ``model_ref_gone`` is True when ``self.model`` is None.
    ``real_model_ref_gone`` is True only when ``self.real_model`` is callable
    and calling it returns None; a non-callable value (including None)
    yields False.
    NOTE(review): ``real_model`` is presumably a weak reference - confirm.
    """
    patcher_missing = self.model is None
    real_ref = self.real_model
    if callable(real_ref):
        real_missing = real_ref() is None
    else:
        real_missing = False
    return patcher_missing, real_missing
|
||||||
|
|
||||||
def is_dead(self):
    """Return True when either flag reported by ``dead_state`` is set."""
    return any(self.dead_state())
|
||||||
|
|
||||||
|
|
||||||
def use_more_memory(extra_memory, loaded_models, device):
|
def use_more_memory(extra_memory, loaded_models, device):
|
||||||
@ -667,6 +678,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
unloaded_models = []
|
unloaded_models = []
|
||||||
|
isolation_active = _isolation_mode_enabled()
|
||||||
|
|
||||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
@ -675,6 +687,17 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
|
if can_unload and isolation_active:
|
||||||
|
try:
|
||||||
|
from pyisolate import flush_tensor_keeper # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
flush_tensor_keeper = None
|
||||||
|
if callable(flush_tensor_keeper):
|
||||||
|
flushed = flush_tensor_keeper()
|
||||||
|
if flushed > 0:
|
||||||
|
logging.debug("][ MM:tensor_keeper_flush | released=%d", flushed)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
can_unload_sorted = sorted(can_unload)
|
can_unload_sorted = sorted(can_unload)
|
||||||
for x in can_unload_sorted:
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
@ -705,7 +728,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded = current_loaded_models.pop(i)
|
||||||
|
model_obj = unloaded.model
|
||||||
|
if model_obj is not None:
|
||||||
|
cleanup = getattr(model_obj, "cleanup", None)
|
||||||
|
if callable(cleanup):
|
||||||
|
cleanup()
|
||||||
|
unloaded_models.append(unloaded)
|
||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
@ -764,7 +793,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
for i in to_unload:
|
for i in to_unload:
|
||||||
model_to_unload = current_loaded_models.pop(i)
|
model_to_unload = current_loaded_models.pop(i)
|
||||||
model_to_unload.model.detach(unpatch_all=False)
|
model_to_unload.model.detach(unpatch_all=False)
|
||||||
model_to_unload.model_finalizer.detach()
|
if model_to_unload.model_finalizer is not None:
|
||||||
|
model_to_unload.model_finalizer.detach()
|
||||||
|
model_to_unload.model_finalizer = None
|
||||||
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
@ -837,25 +868,62 @@ def loaded_models(only_currently_used=False):
|
|||||||
|
|
||||||
|
|
||||||
def cleanup_models_gc():
|
def cleanup_models_gc():
|
||||||
do_gc = False
|
|
||||||
|
|
||||||
reset_cast_buffers()
|
reset_cast_buffers()
|
||||||
|
if not _isolation_mode_enabled():
|
||||||
|
dead_found = False
|
||||||
|
for i in range(len(current_loaded_models)):
|
||||||
|
if current_loaded_models[i].is_dead():
|
||||||
|
dead_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if dead_found:
|
||||||
|
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||||
|
gc.collect()
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||||
|
cur = current_loaded_models[i]
|
||||||
|
if cur.is_dead():
|
||||||
|
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||||
|
leaked = current_loaded_models.pop(i)
|
||||||
|
model_obj = getattr(leaked, "model", None)
|
||||||
|
if model_obj is not None:
|
||||||
|
cleanup = getattr(model_obj, "cleanup", None)
|
||||||
|
if callable(cleanup):
|
||||||
|
cleanup()
|
||||||
|
return
|
||||||
|
|
||||||
|
dead_found = False
|
||||||
|
has_real_model_leak = False
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models)):
|
||||||
cur = current_loaded_models[i]
|
model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
|
||||||
if cur.is_dead():
|
if model_ref_gone or real_model_ref_gone:
|
||||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
dead_found = True
|
||||||
do_gc = True
|
if real_model_ref_gone and not model_ref_gone:
|
||||||
break
|
has_real_model_leak = True
|
||||||
|
|
||||||
if do_gc:
|
if dead_found:
|
||||||
|
if has_real_model_leak:
|
||||||
|
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
|
||||||
|
else:
|
||||||
|
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
for i in range(len(current_loaded_models)):
|
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||||
cur = current_loaded_models[i]
|
cur = current_loaded_models[i]
|
||||||
if cur.is_dead():
|
model_ref_gone, real_model_ref_gone = cur.dead_state()
|
||||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
if model_ref_gone or real_model_ref_gone:
|
||||||
|
if real_model_ref_gone and not model_ref_gone:
|
||||||
|
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
|
||||||
|
else:
|
||||||
|
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
|
||||||
|
leaked = current_loaded_models.pop(i)
|
||||||
|
model_obj = getattr(leaked, "model", None)
|
||||||
|
if model_obj is not None:
|
||||||
|
cleanup = getattr(model_obj, "cleanup", None)
|
||||||
|
if callable(cleanup):
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
|
||||||
def archive_model_dtypes(model):
|
def archive_model_dtypes(model):
|
||||||
@ -869,11 +937,20 @@ def archive_model_dtypes(model):
|
|||||||
def cleanup_models():
    """Drop entries of ``current_loaded_models`` whose real model is gone.

    An entry is stale when its ``real_model`` attribute is None, or is
    callable and returns None.  Each stale entry is popped from the list and,
    if its ``model`` attribute is not None and has a callable ``cleanup``,
    that callable is invoked.
    """
    stale_indices = []
    for index, entry in enumerate(current_loaded_models):
        ref = entry.real_model
        if ref is None or (callable(ref) and ref() is None):
            # Prepending keeps the indices in descending order, so popping
            # one never shifts an index that is still pending.
            stale_indices.insert(0, index)

    for index in stale_indices:
        removed = current_loaded_models.pop(index)
        patcher = getattr(removed, "model", None)
        if patcher is not None:
            hook = getattr(patcher, "cleanup", None)
            if callable(hook):
                hook()
        del removed
|
||||||
|
|
||||||
def dtype_size(dtype):
|
def dtype_size(dtype):
|
||||||
|
|||||||
@ -11,12 +11,14 @@ from functools import partial
|
|||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import comfy.sampler_helpers
|
import comfy.sampler_helpers
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy.cli_args import args
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
@ -210,9 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
|||||||
_calc_cond_batch,
|
_calc_cond_batch,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
result = executor.execute(model, conds, x_in, timestep, model_options)
|
||||||
|
return result
|
||||||
|
|
||||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
|
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
out_conds = []
|
out_conds = []
|
||||||
out_counts = []
|
out_counts = []
|
||||||
# separate conds by matching hooks
|
# separate conds by matching hooks
|
||||||
@ -269,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
for k, v in to_run[tt][0].conditioning.items():
|
for k, v in to_run[tt][0].conditioning.items():
|
||||||
cond_shapes[k].append(v.size())
|
cond_shapes[k].append(v.size())
|
||||||
|
|
||||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
|
||||||
|
if memory_required * 1.5 < free_memory:
|
||||||
to_batch = batch_amount
|
to_batch = batch_amount
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -294,9 +299,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
patches = p.patches
|
patches = p.patches
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
batch_chunks = len(cond_or_uncond)
|
||||||
input_x = torch.cat(input_x)
|
if isolation_active:
|
||||||
|
target_device = model.load_device if hasattr(model, "load_device") else input_x[0].device
|
||||||
|
input_x = torch.cat(input_x).to(target_device)
|
||||||
|
else:
|
||||||
|
input_x = torch.cat(input_x)
|
||||||
c = cond_cat(c)
|
c = cond_cat(c)
|
||||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
if isolation_active:
|
||||||
|
timestep_ = torch.cat([timestep] * batch_chunks).to(target_device)
|
||||||
|
mult = [m.to(target_device) if hasattr(m, "to") else m for m in mult]
|
||||||
|
else:
|
||||||
|
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||||
|
|
||||||
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
transformer_options = model.current_patcher.apply_hooks(hooks=hooks)
|
||||||
if 'transformer_options' in model_options:
|
if 'transformer_options' in model_options:
|
||||||
@ -327,9 +340,17 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
cond_index = cond_or_uncond[o]
|
cond_index = cond_or_uncond[o]
|
||||||
a = area[o]
|
a = area[o]
|
||||||
|
out_t = output[o]
|
||||||
|
mult_t = mult[o]
|
||||||
|
if isolation_active:
|
||||||
|
target_dev = out_conds[cond_index].device
|
||||||
|
if hasattr(out_t, "device") and out_t.device != target_dev:
|
||||||
|
out_t = out_t.to(target_dev)
|
||||||
|
if hasattr(mult_t, "device") and mult_t.device != target_dev:
|
||||||
|
mult_t = mult_t.to(target_dev)
|
||||||
if a is None:
|
if a is None:
|
||||||
out_conds[cond_index] += output[o] * mult[o]
|
out_conds[cond_index] += out_t * mult_t
|
||||||
out_counts[cond_index] += mult[o]
|
out_counts[cond_index] += mult_t
|
||||||
else:
|
else:
|
||||||
out_c = out_conds[cond_index]
|
out_c = out_conds[cond_index]
|
||||||
out_cts = out_counts[cond_index]
|
out_cts = out_counts[cond_index]
|
||||||
@ -337,8 +358,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
for i in range(dims):
|
for i in range(dims):
|
||||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||||
out_c += output[o] * mult[o]
|
out_c += out_t * mult_t
|
||||||
out_cts += mult[o]
|
out_cts += mult_t
|
||||||
|
|
||||||
for i in range(len(out_conds)):
|
for i in range(len(out_conds)):
|
||||||
out_conds[i] /= out_counts[i]
|
out_conds[i] /= out_counts[i]
|
||||||
@ -392,14 +413,31 @@ class KSamplerX0Inpaint:
|
|||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.sigmas = sigmas
|
self.sigmas = sigmas
|
||||||
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
|
||||||
|
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
|
if isolation_active and denoise_mask.device != x.device:
|
||||||
|
denoise_mask = denoise_mask.to(x.device)
|
||||||
if "denoise_mask_function" in model_options:
|
if "denoise_mask_function" in model_options:
|
||||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
|
if isolation_active:
|
||||||
|
latent_image = self.latent_image
|
||||||
|
if hasattr(latent_image, "device") and latent_image.device != x.device:
|
||||||
|
latent_image = latent_image.to(x.device)
|
||||||
|
scaled = self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=latent_image)
|
||||||
|
if hasattr(scaled, "device") and scaled.device != x.device:
|
||||||
|
scaled = scaled.to(x.device)
|
||||||
|
else:
|
||||||
|
scaled = self.inner_model.inner_model.scale_latent_inpaint(
|
||||||
|
x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image
|
||||||
|
)
|
||||||
|
x = x * denoise_mask + scaled * latent_mask
|
||||||
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out = out * denoise_mask + self.latent_image * latent_mask
|
latent_image = self.latent_image
|
||||||
|
if isolation_active and hasattr(latent_image, "device") and latent_image.device != out.device:
|
||||||
|
latent_image = latent_image.to(out.device)
|
||||||
|
out = out * denoise_mask + latent_image * latent_mask
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def simple_scheduler(model_sampling, steps):
|
def simple_scheduler(model_sampling, steps):
|
||||||
@ -741,7 +779,11 @@ class KSAMPLER(Sampler):
|
|||||||
else:
|
else:
|
||||||
model_k.noise = noise
|
model_k.noise = noise
|
||||||
|
|
||||||
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))
|
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||||
|
model_sampling = model_wrap.inner_model.model_sampling
|
||||||
|
noise = model_sampling.noise_scaling(
|
||||||
|
sigmas[0], noise, latent_image, max_denoise
|
||||||
|
)
|
||||||
|
|
||||||
k_callback = None
|
k_callback = None
|
||||||
total_steps = len(sigmas) - 1
|
total_steps = len(sigmas) - 1
|
||||||
|
|||||||
@ -92,7 +92,7 @@ if args.cuda_malloc:
|
|||||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||||
if env_var is None:
|
if env_var is None:
|
||||||
env_var = "backend:cudaMallocAsync"
|
env_var = "backend:cudaMallocAsync"
|
||||||
else:
|
elif not args.use_process_isolation:
|
||||||
env_var += ",backend:cudaMallocAsync"
|
env_var += ",backend:cudaMallocAsync"
|
||||||
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user