feat(isolation): DynamicVRAM compatibility for process isolation

DynamicVRAM's on-demand model loading/offloading conflicted with process isolation in three ways: RPC tensor transport stalls from mid-call GPU offload, race conditions between model lifecycle and active RPC operations, and false positive memory leak detection from changed finalizer patterns.

- Marshal CUDA tensors to CPU before RPC transport for dynamic models
- Add operation state tracking + quiescence waits at workflow boundaries
- Distinguish proxy reference release from actual leaks in cleanup_models_gc
- Fix init order: DynamicVRAM must initialize before isolation proxies
- Add RPC timeouts to prevent indefinite hangs on model unavailability
- Prevent proxy-of-proxy chains from DynamicVRAM model reload cycles
- Add torch.device/torch.dtype serializers for new DynamicVRAM RPC paths
- Guard isolation overhead so non-isolated workflows are unaffected
- Migrate env var PYISOLATE_ISOLATION_ACTIVE to PYISOLATE_CHILD
This commit is contained in:
John Pollock 2026-03-04 23:48:02 -06:00
parent a0f8784e9f
commit 9250191c65
38 changed files with 94595 additions and 307 deletions

View File

@ -0,0 +1 @@
f03e4c88e21504c3

View File

@ -0,0 +1,81 @@
{
"DepthAnything_V2": {
"input_types": {
"required": {
"da_model": {
"__pyisolate_tuple__": [
"DAMODEL"
]
},
"images": {
"__pyisolate_tuple__": [
"IMAGE"
]
}
}
},
"return_types": [
"IMAGE"
],
"return_names": [
"image"
],
"function": "process",
"category": "DepthAnythingV2",
"output_node": false,
"output_is_list": null,
"is_v3": false,
"display_name": "Depth Anything V2"
},
"DownloadAndLoadDepthAnythingV2Model": {
"input_types": {
"required": {
"model": {
"__pyisolate_tuple__": [
[
"depth_anything_v2_vits_fp16.safetensors",
"depth_anything_v2_vits_fp32.safetensors",
"depth_anything_v2_vitb_fp16.safetensors",
"depth_anything_v2_vitb_fp32.safetensors",
"depth_anything_v2_vitl_fp16.safetensors",
"depth_anything_v2_vitl_fp32.safetensors",
"depth_anything_v2_vitg_fp32.safetensors",
"depth_anything_v2_metric_hypersim_vitl_fp32.safetensors",
"depth_anything_v2_metric_vkitti_vitl_fp32.safetensors"
],
{
"default": "depth_anything_v2_vitl_fp32.safetensors"
}
]
}
},
"optional": {
"precision": {
"__pyisolate_tuple__": [
[
"auto",
"bf16",
"fp16",
"fp32"
],
{
"default": "auto"
}
]
}
}
},
"return_types": [
"DAMODEL"
],
"return_names": [
"da_v2_model"
],
"function": "loadmodel",
"category": "DepthAnythingV2",
"output_node": false,
"output_is_list": null,
"is_v3": false,
"display_name": "DownloadAndLoadDepthAnythingV2Model"
}
}

View File

@ -0,0 +1 @@
4b90e6876f4c0b8c

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@ from importlib.metadata import version
import requests import requests
from typing_extensions import NotRequired from typing_extensions import NotRequired
from utils.install_util import get_missing_requirements_message, requirements_path from utils.install_util import get_missing_requirements_message, get_required_packages_versions
from comfy.cli_args import DEFAULT_VERSION_STRING from comfy.cli_args import DEFAULT_VERSION_STRING
import app.logger import app.logger
@ -45,25 +45,7 @@ def get_installed_frontend_version():
def get_required_frontend_version(): def get_required_frontend_version():
"""Get the required frontend version from requirements.txt.""" return get_required_packages_versions().get("comfyui-frontend-package", None)
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-frontend-package=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-frontend-package not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required frontend version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
def check_frontend_version(): def check_frontend_version():
@ -217,25 +199,7 @@ class FrontendManager:
@classmethod @classmethod
def get_required_templates_version(cls) -> str: def get_required_templates_version(cls) -> str:
"""Get the required workflow templates version from requirements.txt.""" return get_required_packages_versions().get("comfyui-workflow-templates", None)
try:
with open(requirements_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("comfyui-workflow-templates=="):
version_str = line.split("==")[-1]
if not is_valid_version(version_str):
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
return None
return version_str
logging.error("comfyui-workflow-templates not found in requirements.txt")
return None
except FileNotFoundError:
logging.error("requirements.txt not found. Cannot determine required templates version.")
return None
except Exception as e:
logging.error(f"Error reading requirements.txt: {e}")
return None
@classmethod @classmethod
def default_frontend_path(cls) -> str: def default_frontend_path(cls) -> str:

View File

@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult" Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops" CublasOps = "cublas_ops"
AutoTune = "autotune" AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@ -262,4 +262,4 @@ else:
args.fast = set(args.fast) args.fast = set(args.fast)
def enables_dynamic_vram(): def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask) matches = torch.nonzero(mask)
if torch.numel(matches) == 0: if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.") return # substep from multi-step sampler: keep self._step from the last full step
self._step = int(matches[0].item()) self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:

View File

@ -64,7 +64,7 @@ class EnumHookScope(enum.Enum):
HookedOnly = "hooked_only" HookedOnly = "hooked_only"
_ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" _ISOLATION_HOOKREF_MODE = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class _HookRef: class _HookRef:

View File

@ -26,6 +26,7 @@ PYISOLATE_VENV_ROOT.mkdir(parents=True, exist_ok=True)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024 _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES = 2 * 1024 * 1024 * 1024
_MODEL_PATCHER_IDLE_TIMEOUT_MS = 120000
def initialize_proxies() -> None: def initialize_proxies() -> None:
@ -168,6 +169,11 @@ def _get_class_types_for_extension(extension_name: str) -> Set[str]:
async def notify_execution_graph(needed_class_types: Set[str]) -> None: async def notify_execution_graph(needed_class_types: Set[str]) -> None:
"""Evict running extensions not needed for current execution.""" """Evict running extensions not needed for current execution."""
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:notify_graph_wait_idle",
)
async def _stop_extension( async def _stop_extension(
ext_name: str, extension: "ComfyNodeExtension", reason: str ext_name: str, extension: "ComfyNodeExtension", reason: str
@ -182,22 +188,33 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
scan_shm_forensics("ISO:stop_extension", refresh_model_context=True) scan_shm_forensics("ISO:stop_extension", refresh_model_context=True)
scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True) scan_shm_forensics("ISO:notify_graph_start", refresh_model_context=True)
isolated_class_types_in_graph = needed_class_types.intersection(
{spec.node_name for spec in _ISOLATED_NODE_SPECS}
)
graph_uses_isolation = bool(isolated_class_types_in_graph)
logger.debug( logger.debug(
"%s ISO:notify_graph_start running=%d needed=%d", "%s ISO:notify_graph_start running=%d needed=%d",
LOG_PREFIX, LOG_PREFIX,
len(_RUNNING_EXTENSIONS), len(_RUNNING_EXTENSIONS),
len(needed_class_types), len(needed_class_types),
) )
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): if graph_uses_isolation:
ext_class_types = _get_class_types_for_extension(ext_name) for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
ext_class_types = _get_class_types_for_extension(ext_name)
# If NONE of this extension's nodes are in the execution graph → evict # If NONE of this extension's nodes are in the execution graph -> evict.
if not ext_class_types.intersection(needed_class_types): if not ext_class_types.intersection(needed_class_types):
await _stop_extension( await _stop_extension(
ext_name, ext_name,
extension, extension,
"isolated custom_node not in execution graph, evicting", "isolated custom_node not in execution graph, evicting",
) )
else:
logger.debug(
"%s ISO:notify_graph_skip_evict running=%d reason=no isolated nodes in graph",
LOG_PREFIX,
len(_RUNNING_EXTENSIONS),
)
# Isolated child processes add steady VRAM pressure; reclaim host-side models # Isolated child processes add steady VRAM pressure; reclaim host-side models
# at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom. # at workflow boundaries so subsequent host nodes (e.g. CLIP encode) keep headroom.
@ -211,7 +228,7 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
_WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES, _WORKFLOW_BOUNDARY_MIN_FREE_VRAM_BYTES,
) )
free_before = model_management.get_free_memory(device) free_before = model_management.get_free_memory(device)
if free_before < required and _RUNNING_EXTENSIONS: if free_before < required and _RUNNING_EXTENSIONS and graph_uses_isolation:
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
await _stop_extension( await _stop_extension(
ext_name, ext_name,
@ -237,6 +254,11 @@ async def notify_execution_graph(needed_class_types: Set[str]) -> None:
async def flush_running_extensions_transport_state() -> int: async def flush_running_extensions_transport_state() -> int:
await wait_for_model_patcher_quiescence(
timeout_ms=_MODEL_PATCHER_IDLE_TIMEOUT_MS,
fail_loud=True,
marker="ISO:flush_transport_wait_idle",
)
total_flushed = 0 total_flushed = 0
for ext_name, extension in list(_RUNNING_EXTENSIONS.items()): for ext_name, extension in list(_RUNNING_EXTENSIONS.items()):
flush_fn = getattr(extension, "flush_transport_state", None) flush_fn = getattr(extension, "flush_transport_state", None)
@ -263,6 +285,50 @@ async def flush_running_extensions_transport_state() -> int:
return total_flushed return total_flushed
async def wait_for_model_patcher_quiescence(
timeout_ms: int = _MODEL_PATCHER_IDLE_TIMEOUT_MS,
*,
fail_loud: bool = False,
marker: str = "ISO:wait_model_patcher_idle",
) -> bool:
try:
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
registry = ModelPatcherRegistry()
start = time.perf_counter()
idle = await registry.wait_all_idle(timeout_ms)
elapsed_ms = (time.perf_counter() - start) * 1000.0
if idle:
logger.debug(
"%s %s idle=1 timeout_ms=%d elapsed_ms=%.3f",
LOG_PREFIX,
marker,
timeout_ms,
elapsed_ms,
)
return True
states = await registry.get_all_operation_states()
logger.error(
"%s %s idle_timeout timeout_ms=%d elapsed_ms=%.3f states=%s",
LOG_PREFIX,
marker,
timeout_ms,
elapsed_ms,
states,
)
if fail_loud:
raise TimeoutError(
f"ModelPatcherRegistry did not quiesce within {timeout_ms} ms"
)
return False
except Exception:
if fail_loud:
raise
logger.debug("%s %s failed", LOG_PREFIX, marker, exc_info=True)
return False
def get_claimed_paths() -> Set[Path]: def get_claimed_paths() -> Set[Path]:
return _CLAIMED_PATHS return _CLAIMED_PATHS
@ -320,6 +386,7 @@ __all__ = [
"await_isolation_loading", "await_isolation_loading",
"notify_execution_graph", "notify_execution_graph",
"flush_running_extensions_transport_state", "flush_running_extensions_transport_state",
"wait_for_model_patcher_quiescence",
"get_claimed_paths", "get_claimed_paths",
"update_rpc_event_loops", "update_rpc_event_loops",
"IsolatedNodeSpec", "IsolatedNodeSpec",

View File

@ -83,6 +83,33 @@ class ComfyUIAdapter(IsolationAdapter):
logging.getLogger(pkg_name).setLevel(logging.ERROR) logging.getLogger(pkg_name).setLevel(logging.ERROR)
def register_serializers(self, registry: SerializerRegistryProtocol) -> None: def register_serializers(self, registry: SerializerRegistryProtocol) -> None:
import torch
def serialize_device(obj: Any) -> Dict[str, Any]:
return {"__type__": "device", "device_str": str(obj)}
def deserialize_device(data: Dict[str, Any]) -> Any:
return torch.device(data["device_str"])
registry.register("device", serialize_device, deserialize_device)
_VALID_DTYPES = {
"float16", "float32", "float64", "bfloat16",
"int8", "int16", "int32", "int64",
"uint8", "bool",
}
def serialize_dtype(obj: Any) -> Dict[str, Any]:
return {"__type__": "dtype", "dtype_str": str(obj)}
def deserialize_dtype(data: Dict[str, Any]) -> Any:
dtype_name = data["dtype_str"].replace("torch.", "")
if dtype_name not in _VALID_DTYPES:
raise ValueError(f"Invalid dtype: {data['dtype_str']}")
return getattr(torch, dtype_name)
registry.register("dtype", serialize_dtype, deserialize_dtype)
def serialize_model_patcher(obj: Any) -> Dict[str, Any]: def serialize_model_patcher(obj: Any) -> Dict[str, Any]:
# Child-side: must already have _instance_id (proxy) # Child-side: must already have _instance_id (proxy)
if os.environ.get("PYISOLATE_CHILD") == "1": if os.environ.get("PYISOLATE_CHILD") == "1":
@ -193,6 +220,10 @@ class ComfyUIAdapter(IsolationAdapter):
f"ModelSampling in child lacks _instance_id: " f"ModelSampling in child lacks _instance_id: "
f"{type(obj).__module__}.{type(obj).__name__}" f"{type(obj).__module__}.{type(obj).__name__}"
) )
# Host-side pass-through for proxies: do not re-register a proxy as a
# new ModelSamplingRef, or we create proxy-of-proxy indirection.
if hasattr(obj, "_instance_id"):
return {"__type__": "ModelSamplingRef", "ms_id": obj._instance_id}
# Host-side: register with ModelSamplingRegistry and return JSON-safe dict # Host-side: register with ModelSamplingRegistry and return JSON-safe dict
ms_id = ModelSamplingRegistry().register(obj) ms_id = ModelSamplingRegistry().register(obj)
return {"__type__": "ModelSamplingRef", "ms_id": ms_id} return {"__type__": "ModelSamplingRef", "ms_id": ms_id}
@ -211,22 +242,21 @@ class ComfyUIAdapter(IsolationAdapter):
else: else:
return ModelSamplingRegistry()._get_instance(data["ms_id"]) return ModelSamplingRegistry()._get_instance(data["ms_id"])
# Register ModelSampling type and proxy # Register all ModelSampling* and StableCascadeSampling classes dynamically
registry.register( import comfy.model_sampling
"ModelSamplingDiscrete",
serialize_model_sampling, for ms_cls in vars(comfy.model_sampling).values():
deserialize_model_sampling, if not isinstance(ms_cls, type):
) continue
registry.register( if not issubclass(ms_cls, torch.nn.Module):
"ModelSamplingContinuousEDM", continue
serialize_model_sampling, if not (ms_cls.__name__.startswith("ModelSampling") or ms_cls.__name__ == "StableCascadeSampling"):
deserialize_model_sampling, continue
) registry.register(
registry.register( ms_cls.__name__,
"ModelSamplingContinuousV", serialize_model_sampling,
serialize_model_sampling, deserialize_model_sampling,
deserialize_model_sampling, )
)
registry.register( registry.register(
"ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling "ModelSamplingProxy", serialize_model_sampling, deserialize_model_sampling
) )

View File

@ -389,7 +389,8 @@ class ComfyNodeExtension(ExtensionBase):
getattr(self, "name", "?"), getattr(self, "name", "?"),
node_name, node_name,
) )
return self._wrap_unpicklable_objects(result) wrapped = self._wrap_unpicklable_objects(result)
return wrapped
if not isinstance(result, tuple): if not isinstance(result, tuple):
result = (result,) result = (result,)
@ -400,7 +401,8 @@ class ComfyNodeExtension(ExtensionBase):
node_name, node_name,
len(result), len(result),
) )
return self._wrap_unpicklable_objects(result) wrapped = self._wrap_unpicklable_objects(result)
return wrapped
async def flush_transport_state(self) -> int: async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1": if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
@ -443,7 +445,10 @@ class ComfyNodeExtension(ExtensionBase):
if isinstance(data, (str, int, float, bool, type(None))): if isinstance(data, (str, int, float, bool, type(None))):
return data return data
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.detach() if data.requires_grad else data tensor = data.detach() if data.requires_grad else data
if os.environ.get("PYISOLATE_CHILD") == "1" and tensor.device.type != "cpu":
return tensor.cpu()
return tensor
# Special-case clip vision outputs: preserve attribute access by packing fields # Special-case clip vision outputs: preserve attribute access by packing fields
if hasattr(data, "penultimate_hidden_states") or hasattr( if hasattr(data, "penultimate_hidden_states") or hasattr(

View File

@ -338,6 +338,34 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def apply_model(self, *args, **kwargs) -> Any: def apply_model(self, *args, **kwargs) -> Any:
import torch import torch
def _preferred_device() -> Any:
for value in args:
if isinstance(value, torch.Tensor):
return value.device
for value in kwargs.values():
if isinstance(value, torch.Tensor):
return value.device
return None
def _move_result_to_device(obj: Any, device: Any) -> Any:
if device is None:
return obj
if isinstance(obj, torch.Tensor):
return obj.to(device) if obj.device != device else obj
if isinstance(obj, dict):
return {k: _move_result_to_device(v, device) for k, v in obj.items()}
if isinstance(obj, list):
return [_move_result_to_device(v, device) for v in obj]
if isinstance(obj, tuple):
return tuple(_move_result_to_device(v, device) for v in obj)
return obj
# DynamicVRAM models must keep load/offload decisions in host process.
# Child-side CUDA staging here can deadlock before first inference RPC.
if self.is_dynamic():
out = self._call_rpc("inner_model_apply_model", args, kwargs)
return _move_result_to_device(out, _preferred_device())
required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs) required_bytes = self._cpu_tensor_bytes(args) + self._cpu_tensor_bytes(kwargs)
self._ensure_apply_model_headroom(required_bytes) self._ensure_apply_model_headroom(required_bytes)
@ -360,7 +388,8 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
args_cuda = _to_cuda(args) args_cuda = _to_cuda(args)
kwargs_cuda = _to_cuda(kwargs) kwargs_cuda = _to_cuda(kwargs)
return self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda) out = self._call_rpc("inner_model_apply_model", args_cuda, kwargs_cuda)
return _move_result_to_device(out, _preferred_device())
def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any: def model_state_dict(self, filter_prefix: Optional[str] = None) -> Any:
keys = self._call_rpc("model_state_dict", filter_prefix) keys = self._call_rpc("model_state_dict", filter_prefix)
@ -526,6 +555,13 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
def memory_required(self, input_shape: Any) -> Any: def memory_required(self, input_shape: Any) -> Any:
return self._call_rpc("memory_required", input_shape) return self._call_rpc("memory_required", input_shape)
def get_operation_state(self) -> Dict[str, Any]:
state = self._call_rpc("get_operation_state")
return state if isinstance(state, dict) else {}
def wait_for_idle(self, timeout_ms: int = 0) -> bool:
return bool(self._call_rpc("wait_for_idle", timeout_ms))
def is_dynamic(self) -> bool: def is_dynamic(self) -> bool:
return bool(self._call_rpc("is_dynamic")) return bool(self._call_rpc("is_dynamic"))
@ -771,6 +807,7 @@ class ModelPatcherProxy(BaseProxy[ModelPatcherRegistry]):
class _InnerModelProxy: class _InnerModelProxy:
def __init__(self, parent: ModelPatcherProxy): def __init__(self, parent: ModelPatcherProxy):
self._parent = parent self._parent = parent
self._model_sampling = None
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
if name.startswith("_"): if name.startswith("_"):
@ -793,7 +830,11 @@ class _InnerModelProxy:
manage_lifecycle=False, manage_lifecycle=False,
) )
if name == "model_sampling": if name == "model_sampling":
return self._parent._call_rpc("get_model_object", "model_sampling") if self._model_sampling is None:
self._model_sampling = self._parent._call_rpc(
"get_model_object", "model_sampling"
)
return self._model_sampling
if name == "extra_conds_shapes": if name == "extra_conds_shapes":
return lambda *a, **k: self._parent._call_rpc( return lambda *a, **k: self._parent._call_rpc(
"inner_model_extra_conds_shapes", a, k "inner_model_extra_conds_shapes", a, k

View File

@ -2,8 +2,12 @@
# RPC server for ModelPatcher isolation (child process) # RPC server for ModelPatcher isolation (child process)
from __future__ import annotations from __future__ import annotations
import asyncio
import gc import gc
import logging import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional, List from typing import Any, Optional, List
try: try:
@ -43,12 +47,191 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class _OperationState:
lease: threading.Lock = field(default_factory=threading.Lock)
active_count: int = 0
active_by_method: dict[str, int] = field(default_factory=dict)
total_operations: int = 0
last_method: Optional[str] = None
last_started_ts: Optional[float] = None
last_ended_ts: Optional[float] = None
last_elapsed_ms: Optional[float] = None
last_error: Optional[str] = None
last_thread_id: Optional[int] = None
last_loop_id: Optional[int] = None
class ModelPatcherRegistry(BaseRegistry[Any]): class ModelPatcherRegistry(BaseRegistry[Any]):
_type_prefix = "model" _type_prefix = "model"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._pending_cleanup_ids: set[str] = set() self._pending_cleanup_ids: set[str] = set()
self._operation_states: dict[str, _OperationState] = {}
self._operation_state_cv = threading.Condition(self._lock)
def _get_or_create_operation_state(self, instance_id: str) -> _OperationState:
state = self._operation_states.get(instance_id)
if state is None:
state = _OperationState()
self._operation_states[instance_id] = state
return state
def _begin_operation(self, instance_id: str, method_name: str) -> tuple[float, float]:
start_epoch = time.time()
start_perf = time.perf_counter()
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
state.active_count += 1
state.active_by_method[method_name] = (
state.active_by_method.get(method_name, 0) + 1
)
state.total_operations += 1
state.last_method = method_name
state.last_started_ts = start_epoch
state.last_thread_id = threading.get_ident()
try:
state.last_loop_id = id(asyncio.get_running_loop())
except RuntimeError:
state.last_loop_id = None
logger.debug(
"ISO:registry_op_start instance_id=%s method=%s start_ts=%.6f thread=%s loop=%s",
instance_id,
method_name,
start_epoch,
threading.get_ident(),
state.last_loop_id,
)
return start_epoch, start_perf
def _end_operation(
self,
instance_id: str,
method_name: str,
start_perf: float,
error: Optional[BaseException] = None,
) -> None:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
state.active_count = max(0, state.active_count - 1)
if method_name in state.active_by_method:
remaining = state.active_by_method[method_name] - 1
if remaining <= 0:
state.active_by_method.pop(method_name, None)
else:
state.active_by_method[method_name] = remaining
state.last_ended_ts = end_epoch
state.last_elapsed_ms = elapsed_ms
state.last_error = None if error is None else repr(error)
if state.active_count == 0:
self._operation_state_cv.notify_all()
logger.debug(
"ISO:registry_op_end instance_id=%s method=%s end_ts=%.6f elapsed_ms=%.3f error=%s",
instance_id,
method_name,
end_epoch,
elapsed_ms,
None if error is None else type(error).__name__,
)
def _run_operation_with_lease(self, instance_id: str, method_name: str, fn):
with self._operation_state_cv:
state = self._get_or_create_operation_state(instance_id)
lease = state.lease
with lease:
_, start_perf = self._begin_operation(instance_id, method_name)
try:
result = fn()
except Exception as exc:
self._end_operation(instance_id, method_name, start_perf, error=exc)
raise
self._end_operation(instance_id, method_name, start_perf)
return result
def _snapshot_operation_state(self, instance_id: str) -> dict[str, Any]:
with self._operation_state_cv:
state = self._operation_states.get(instance_id)
if state is None:
return {
"instance_id": instance_id,
"active_count": 0,
"active_methods": [],
"total_operations": 0,
"last_method": None,
"last_started_ts": None,
"last_ended_ts": None,
"last_elapsed_ms": None,
"last_error": None,
"last_thread_id": None,
"last_loop_id": None,
}
return {
"instance_id": instance_id,
"active_count": state.active_count,
"active_methods": sorted(state.active_by_method.keys()),
"total_operations": state.total_operations,
"last_method": state.last_method,
"last_started_ts": state.last_started_ts,
"last_ended_ts": state.last_ended_ts,
"last_elapsed_ms": state.last_elapsed_ms,
"last_error": state.last_error,
"last_thread_id": state.last_thread_id,
"last_loop_id": state.last_loop_id,
}
def unregister_sync(self, instance_id: str) -> None:
with self._operation_state_cv:
instance = self._registry.pop(instance_id, None)
if instance is not None:
self._id_map.pop(id(instance), None)
self._pending_cleanup_ids.discard(instance_id)
self._operation_states.pop(instance_id, None)
self._operation_state_cv.notify_all()
async def get_operation_state(self, instance_id: str) -> dict[str, Any]:
return self._snapshot_operation_state(instance_id)
async def get_all_operation_states(self) -> dict[str, dict[str, Any]]:
with self._operation_state_cv:
ids = sorted(self._operation_states.keys())
return {instance_id: self._snapshot_operation_state(instance_id) for instance_id in ids}
async def wait_for_idle(self, instance_id: str, timeout_ms: int = 0) -> bool:
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
with self._operation_state_cv:
while True:
active = self._operation_states.get(instance_id)
if active is None or active.active_count == 0:
return True
if deadline is None:
self._operation_state_cv.wait()
continue
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
self._operation_state_cv.wait(timeout=remaining)
async def wait_all_idle(self, timeout_ms: int = 0) -> bool:
timeout_s = None if timeout_ms <= 0 else (timeout_ms / 1000.0)
deadline = None if timeout_s is None else (time.monotonic() + timeout_s)
with self._operation_state_cv:
while True:
has_active = any(
state.active_count > 0 for state in self._operation_states.values()
)
if not has_active:
return True
if deadline is None:
self._operation_state_cv.wait()
continue
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
self._operation_state_cv.wait(timeout=remaining)
async def clone(self, instance_id: str) -> str: async def clone(self, instance_id: str) -> str:
instance = self._get_instance(instance_id) instance = self._get_instance(instance_id)
@ -73,8 +256,14 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
) )
registry = ModelSamplingRegistry() registry = ModelSamplingRegistry()
sampling_id = registry.register(result) # Preserve identity when upstream already returned a proxy. Re-registering
# a proxy object creates proxy-of-proxy call chains.
if isinstance(result, ModelSamplingProxy):
sampling_id = result._instance_id
else:
sampling_id = registry.register(result)
return ModelSamplingProxy(sampling_id, registry) return ModelSamplingProxy(sampling_id, registry)
return detach_if_grad(result) return detach_if_grad(result)
async def get_model_options(self, instance_id: str) -> dict: async def get_model_options(self, instance_id: str) -> dict:
@ -163,7 +352,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return self._get_instance(instance_id).lowvram_patch_counter() return self._get_instance(instance_id).lowvram_patch_counter()
async def memory_required(self, instance_id: str, input_shape: Any) -> Any: async def memory_required(self, instance_id: str, input_shape: Any) -> Any:
return self._get_instance(instance_id).memory_required(input_shape) return self._run_operation_with_lease(
instance_id,
"memory_required",
lambda: self._get_instance(instance_id).memory_required(input_shape),
)
async def is_dynamic(self, instance_id: str) -> bool: async def is_dynamic(self, instance_id: str) -> bool:
instance = self._get_instance(instance_id) instance = self._get_instance(instance_id)
@ -186,7 +379,11 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
return None return None
async def model_dtype(self, instance_id: str) -> Any: async def model_dtype(self, instance_id: str) -> Any:
return self._get_instance(instance_id).model_dtype() return self._run_operation_with_lease(
instance_id,
"model_dtype",
lambda: self._get_instance(instance_id).model_dtype(),
)
async def model_patches_to(self, instance_id: str, device: Any) -> Any: async def model_patches_to(self, instance_id: str, device: Any) -> Any:
return self._get_instance(instance_id).model_patches_to(device) return self._get_instance(instance_id).model_patches_to(device)
@ -198,8 +395,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
extra_memory: Any, extra_memory: Any,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> Any: ) -> Any:
return self._get_instance(instance_id).partially_load( return self._run_operation_with_lease(
device, extra_memory, force_patch_weights=force_patch_weights instance_id,
"partially_load",
lambda: self._get_instance(instance_id).partially_load(
device, extra_memory, force_patch_weights=force_patch_weights
),
) )
async def partially_unload( async def partially_unload(
@ -209,8 +410,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
memory_to_free: int = 0, memory_to_free: int = 0,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> int: ) -> int:
return self._get_instance(instance_id).partially_unload( return self._run_operation_with_lease(
device_to, memory_to_free, force_patch_weights instance_id,
"partially_unload",
lambda: self._get_instance(instance_id).partially_unload(
device_to, memory_to_free, force_patch_weights
),
) )
async def load( async def load(
@ -221,8 +426,12 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
force_patch_weights: bool = False, force_patch_weights: bool = False,
full_load: bool = False, full_load: bool = False,
) -> None: ) -> None:
self._get_instance(instance_id).load( self._run_operation_with_lease(
device_to, lowvram_model_memory, force_patch_weights, full_load instance_id,
"load",
lambda: self._get_instance(instance_id).load(
device_to, lowvram_model_memory, force_patch_weights, full_load
),
) )
async def patch_model( async def patch_model(
@ -233,20 +442,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
load_weights: bool = True, load_weights: bool = True,
force_patch_weights: bool = False, force_patch_weights: bool = False,
) -> None: ) -> None:
try: def _invoke() -> None:
self._get_instance(instance_id).patch_model( try:
device_to, lowvram_model_memory, load_weights, force_patch_weights self._get_instance(instance_id).patch_model(
) device_to, lowvram_model_memory, load_weights, force_patch_weights
except AttributeError as e: )
logger.error( except AttributeError as e:
f"Isolation Error: Failed to patch model attribute: {e}. Skipping." logger.error(
) f"Isolation Error: Failed to patch model attribute: {e}. Skipping."
return )
return
self._run_operation_with_lease(instance_id, "patch_model", _invoke)
async def unpatch_model( async def unpatch_model(
self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True self, instance_id: str, device_to: Any = None, unpatch_weights: bool = True
) -> None: ) -> None:
self._get_instance(instance_id).unpatch_model(device_to, unpatch_weights) self._run_operation_with_lease(
instance_id,
"unpatch_model",
lambda: self._get_instance(instance_id).unpatch_model(
device_to, unpatch_weights
),
)
async def detach(self, instance_id: str, unpatch_all: bool = True) -> None: async def detach(self, instance_id: str, unpatch_all: bool = True) -> None:
self._get_instance(instance_id).detach(unpatch_all) self._get_instance(instance_id).detach(unpatch_all)
@ -262,26 +480,29 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
self._get_instance(instance_id).pre_run() self._get_instance(instance_id).pre_run()
async def cleanup(self, instance_id: str) -> None: async def cleanup(self, instance_id: str) -> None:
try: def _invoke() -> None:
instance = self._get_instance(instance_id) try:
except Exception: instance = self._get_instance(instance_id)
logger.debug( except Exception:
"ModelPatcher cleanup requested for missing instance %s", logger.debug(
instance_id, "ModelPatcher cleanup requested for missing instance %s",
exc_info=True, instance_id,
) exc_info=True,
return )
return
try: try:
instance.cleanup() instance.cleanup()
finally: finally:
with self._lock: with self._lock:
self._pending_cleanup_ids.add(instance_id) self._pending_cleanup_ids.add(instance_id)
gc.collect() gc.collect()
self._run_operation_with_lease(instance_id, "cleanup", _invoke)
def sweep_pending_cleanup(self) -> int: def sweep_pending_cleanup(self) -> int:
removed = 0 removed = 0
with self._lock: with self._operation_state_cv:
pending_ids = list(self._pending_cleanup_ids) pending_ids = list(self._pending_cleanup_ids)
self._pending_cleanup_ids.clear() self._pending_cleanup_ids.clear()
for instance_id in pending_ids: for instance_id in pending_ids:
@ -289,17 +510,21 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
if instance is None: if instance is None:
continue continue
self._id_map.pop(id(instance), None) self._id_map.pop(id(instance), None)
self._operation_states.pop(instance_id, None)
removed += 1 removed += 1
self._operation_state_cv.notify_all()
gc.collect() gc.collect()
return removed return removed
def purge_all(self) -> int: def purge_all(self) -> int:
with self._lock: with self._operation_state_cv:
removed = len(self._registry) removed = len(self._registry)
self._registry.clear() self._registry.clear()
self._id_map.clear() self._id_map.clear()
self._pending_cleanup_ids.clear() self._pending_cleanup_ids.clear()
self._operation_states.clear()
self._operation_state_cv.notify_all()
gc.collect() gc.collect()
return removed return removed
@ -743,17 +968,52 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_memory_required( async def inner_model_memory_required(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.memory_required(*args, **kwargs) return self._run_operation_with_lease(
instance_id,
"inner_model_memory_required",
lambda: self._get_instance(instance_id).model.memory_required(
*args, **kwargs
),
)
async def inner_model_extra_conds_shapes( async def inner_model_extra_conds_shapes(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.extra_conds_shapes(*args, **kwargs) return self._run_operation_with_lease(
instance_id,
"inner_model_extra_conds_shapes",
lambda: self._get_instance(instance_id).model.extra_conds_shapes(
*args, **kwargs
),
)
async def inner_model_extra_conds( async def inner_model_extra_conds(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return self._get_instance(instance_id).model.extra_conds(*args, **kwargs) def _invoke() -> Any:
result = self._get_instance(instance_id).model.extra_conds(*args, **kwargs)
try:
import torch
import comfy.conds
except Exception:
return result
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
if isinstance(obj, comfy.conds.CONDRegular):
return type(obj)(_to_cpu(obj.cond))
return obj
return _to_cpu(result)
return self._run_operation_with_lease(instance_id, "inner_model_extra_conds", _invoke)
async def inner_model_state_dict( async def inner_model_state_dict(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
@ -767,82 +1027,177 @@ class ModelPatcherRegistry(BaseRegistry[Any]):
async def inner_model_apply_model( async def inner_model_apply_model(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) def _invoke() -> Any:
target = getattr(instance, "load_device", None) import torch
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj): instance = self._get_instance(instance_id)
if target is None: target = getattr(instance, "load_device", None)
if target is None and args and hasattr(args[0], "device"):
target = args[0].device
elif target is None:
for v in kwargs.values():
if hasattr(v, "device"):
target = v.device
break
def _move(obj):
if target is None:
return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj return obj
if isinstance(obj, (tuple, list)):
return type(obj)(_move(o) for o in obj)
if hasattr(obj, "to"):
return obj.to(target)
return obj
moved_args = tuple(_move(a) for a in args) moved_args = tuple(_move(a) for a in args)
moved_kwargs = {k: _move(v) for k, v in kwargs.items()} moved_kwargs = {k: _move(v) for k, v in kwargs.items()}
result = instance.model.apply_model(*moved_args, **moved_kwargs) result = instance.model.apply_model(*moved_args, **moved_kwargs)
return detach_if_grad(_move(result)) moved_result = detach_if_grad(_move(result))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore device placement in the child process.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "inner_model_apply_model", _invoke)
async def process_latent_in( async def process_latent_in(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
return detach_if_grad( import torch
self._get_instance(instance_id).model.process_latent_in(*args, **kwargs)
) def _invoke() -> Any:
instance = self._get_instance(instance_id)
result = detach_if_grad(instance.model.process_latent_in(*args, **kwargs))
# DynamicVRAM + isolation: returning CUDA tensors across RPC can stall
# at the transport boundary. Marshal dynamic-path results as CPU and let
# the proxy restore placement when needed.
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(result)
return result
return self._run_operation_with_lease(instance_id, "process_latent_in", _invoke)
async def process_latent_out( async def process_latent_out(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) import torch
result = instance.model.process_latent_out(*args, **kwargs)
try: def _invoke() -> Any:
target = None instance = self._get_instance(instance_id)
if args and hasattr(args[0], "device"): result = instance.model.process_latent_out(*args, **kwargs)
target = args[0].device moved_result = None
elif kwargs: try:
for v in kwargs.values(): target = None
if hasattr(v, "device"): if args and hasattr(args[0], "device"):
target = v.device target = args[0].device
break elif kwargs:
if target is not None and hasattr(result, "to"): for v in kwargs.values():
return detach_if_grad(result.to(target)) if hasattr(v, "device"):
except Exception: target = v.device
logger.debug( break
"process_latent_out: failed to move result to target device", if target is not None and hasattr(result, "to"):
exc_info=True, moved_result = detach_if_grad(result.to(target))
) except Exception:
return detach_if_grad(result) logger.debug(
"process_latent_out: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(instance_id, "process_latent_out", _invoke)
async def scale_latent_inpaint( async def scale_latent_inpaint(
self, instance_id: str, args: tuple, kwargs: dict self, instance_id: str, args: tuple, kwargs: dict
) -> Any: ) -> Any:
instance = self._get_instance(instance_id) import torch
result = instance.model.scale_latent_inpaint(*args, **kwargs)
try: def _invoke() -> Any:
target = None instance = self._get_instance(instance_id)
if args and hasattr(args[0], "device"): result = instance.model.scale_latent_inpaint(*args, **kwargs)
target = args[0].device moved_result = None
elif kwargs: try:
for v in kwargs.values(): target = None
if hasattr(v, "device"): if args and hasattr(args[0], "device"):
target = v.device target = args[0].device
break elif kwargs:
if target is not None and hasattr(result, "to"): for v in kwargs.values():
return detach_if_grad(result.to(target)) if hasattr(v, "device"):
except Exception: target = v.device
logger.debug( break
"scale_latent_inpaint: failed to move result to target device", if target is not None and hasattr(result, "to"):
exc_info=True, moved_result = detach_if_grad(result.to(target))
) except Exception:
return detach_if_grad(result) logger.debug(
"scale_latent_inpaint: failed to move result to target device",
exc_info=True,
)
if moved_result is None:
moved_result = detach_if_grad(result)
is_dynamic_fn = getattr(instance, "is_dynamic", None)
if callable(is_dynamic_fn) and is_dynamic_fn():
def _to_cpu(obj: Any) -> Any:
if torch.is_tensor(obj):
return obj.detach().cpu() if obj.device.type != "cpu" else obj
if isinstance(obj, dict):
return {k: _to_cpu(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_cpu(v) for v in obj]
if isinstance(obj, tuple):
return tuple(_to_cpu(v) for v in obj)
return obj
return _to_cpu(moved_result)
return moved_result
return self._run_operation_with_lease(
instance_id, "scale_latent_inpaint", _invoke
)
async def load_lora( async def load_lora(
self, self,

View File

@ -3,6 +3,9 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os
import threading
import time
from typing import Any from typing import Any
from comfy.isolation.proxies.base import ( from comfy.isolation.proxies.base import (
@ -16,6 +19,22 @@ from comfy.isolation.proxies.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _describe_value(obj: Any) -> str:
try:
import torch
except Exception:
torch = None
try:
if torch is not None and isinstance(obj, torch.Tensor):
return (
"Tensor(shape=%s,dtype=%s,device=%s,id=%s)"
% (tuple(obj.shape), obj.dtype, obj.device, id(obj))
)
except Exception:
pass
return "%s(id=%s)" % (type(obj).__name__, id(obj))
def _prefer_device(*tensors: Any) -> Any: def _prefer_device(*tensors: Any) -> Any:
try: try:
import torch import torch
@ -49,6 +68,24 @@ def _to_device(obj: Any, device: Any) -> Any:
return obj return obj
def _to_cpu_for_rpc(obj: Any) -> Any:
try:
import torch
except Exception:
return obj
if isinstance(obj, torch.Tensor):
t = obj.detach() if obj.requires_grad else obj
if t.is_cuda:
return t.to("cpu")
return t
if isinstance(obj, (list, tuple)):
converted = [_to_cpu_for_rpc(x) for x in obj]
return type(obj)(converted) if isinstance(obj, tuple) else converted
if isinstance(obj, dict):
return {k: _to_cpu_for_rpc(v) for k, v in obj.items()}
return obj
class ModelSamplingRegistry(BaseRegistry[Any]): class ModelSamplingRegistry(BaseRegistry[Any]):
_type_prefix = "modelsampling" _type_prefix = "modelsampling"
@ -199,14 +236,70 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
rpc = self._get_rpc() rpc = self._get_rpc()
method = getattr(rpc, method_name) method = getattr(rpc, method_name)
result = method(self._instance_id, *args) result = method(self._instance_id, *args)
timeout_ms = self._rpc_timeout_ms()
start_epoch = time.time()
start_perf = time.perf_counter()
thread_id = threading.get_ident()
call_id = "%s:%s:%s:%.6f" % (
self._instance_id,
method_name,
thread_id,
start_perf,
)
logger.debug(
"ISO:modelsampling_rpc_start method=%s instance_id=%s call_id=%s start_ts=%.6f thread=%s timeout_ms=%s",
method_name,
self._instance_id,
call_id,
start_epoch,
thread_id,
timeout_ms,
)
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
result = asyncio.wait_for(result, timeout=timeout_ms / 1000.0)
try: try:
asyncio.get_running_loop() asyncio.get_running_loop()
return run_coro_in_new_loop(result) out = run_coro_in_new_loop(result)
except RuntimeError: except RuntimeError:
loop = get_thread_loop() loop = get_thread_loop()
return loop.run_until_complete(result) out = loop.run_until_complete(result)
return result else:
out = result
logger.debug(
"ISO:modelsampling_rpc_after_await method=%s instance_id=%s call_id=%s out=%s",
method_name,
self._instance_id,
call_id,
_describe_value(out),
)
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:modelsampling_rpc_end method=%s instance_id=%s call_id=%s elapsed_ms=%.3f thread=%s",
method_name,
self._instance_id,
call_id,
elapsed_ms,
thread_id,
)
logger.debug(
"ISO:modelsampling_rpc_return method=%s instance_id=%s call_id=%s",
method_name,
self._instance_id,
call_id,
)
return out
@staticmethod
def _rpc_timeout_ms() -> int:
raw = os.environ.get(
"COMFY_ISOLATION_MODEL_SAMPLING_RPC_TIMEOUT_MS",
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "30000"),
)
try:
timeout_ms = int(raw)
except ValueError:
timeout_ms = 30000
return max(1, timeout_ms)
@property @property
def sigma_min(self) -> Any: def sigma_min(self) -> Any:
@ -235,10 +328,24 @@ class ModelSamplingProxy(BaseProxy[ModelSamplingRegistry]):
def noise_scaling( def noise_scaling(
self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False self, sigma: Any, noise: Any, latent_image: Any, max_denoise: bool = False
) -> Any: ) -> Any:
return self._call("noise_scaling", sigma, noise, latent_image, max_denoise) preferred_device = _prefer_device(noise, latent_image)
out = self._call(
"noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(noise),
_to_cpu_for_rpc(latent_image),
max_denoise,
)
return _to_device(out, preferred_device)
def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any: def inverse_noise_scaling(self, sigma: Any, latent: Any) -> Any:
return self._call("inverse_noise_scaling", sigma, latent) preferred_device = _prefer_device(latent)
out = self._call(
"inverse_noise_scaling",
_to_cpu_for_rpc(sigma),
_to_cpu_for_rpc(latent),
)
return _to_device(out, preferred_device)
def timestep(self, sigma: Any) -> Any: def timestep(self, sigma: Any) -> Any:
return self._call("timestep", sigma) return self._call("timestep", sigma)

View File

@ -2,9 +2,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import concurrent.futures
import logging import logging
import os import os
import threading import threading
import time
import weakref import weakref
from typing import Any, Callable, Dict, Generic, Optional, TypeVar from typing import Any, Callable, Dict, Generic, Optional, TypeVar
@ -119,6 +121,24 @@ def set_global_loop(loop: asyncio.AbstractEventLoop) -> None:
class BaseProxy(Generic[T]): class BaseProxy(Generic[T]):
_registry_class: type = BaseRegistry # type: ignore[type-arg] _registry_class: type = BaseRegistry # type: ignore[type-arg]
__module__: str = "comfy.isolation.proxies.base" __module__: str = "comfy.isolation.proxies.base"
_TIMEOUT_RPC_METHODS = frozenset(
{
"partially_load",
"partially_unload",
"load",
"patch_model",
"unpatch_model",
"inner_model_apply_model",
"memory_required",
"model_dtype",
"inner_model_memory_required",
"inner_model_extra_conds_shapes",
"inner_model_extra_conds",
"process_latent_in",
"process_latent_out",
"scale_latent_inpaint",
}
)
def __init__( def __init__(
self, self,
@ -148,39 +168,89 @@ class BaseProxy(Generic[T]):
) )
return self._rpc_caller return self._rpc_caller
def _rpc_timeout_ms_for_method(self, method_name: str) -> Optional[int]:
if method_name not in self._TIMEOUT_RPC_METHODS:
return None
try:
timeout_ms = int(
os.environ.get("COMFY_ISOLATION_LOAD_RPC_TIMEOUT_MS", "120000")
)
except ValueError:
timeout_ms = 120000
return max(1, timeout_ms)
def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any: def _call_rpc(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
rpc = self._get_rpc() rpc = self._get_rpc()
method = getattr(rpc, method_name) method = getattr(rpc, method_name)
timeout_ms = self._rpc_timeout_ms_for_method(method_name)
coro = method(self._instance_id, *args, **kwargs) coro = method(self._instance_id, *args, **kwargs)
if timeout_ms is not None:
coro = asyncio.wait_for(coro, timeout=timeout_ms / 1000.0)
# If we have a global loop (Main Thread Loop), use it for dispatch from worker threads start_epoch = time.time()
if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running(): start_perf = time.perf_counter()
try: thread_id = threading.get_ident()
# If we are already in the global loop, we can't block on it? try:
# Actually, this method is synchronous (__getattr__ -> lambda). running_loop = asyncio.get_running_loop()
# If called from async context in main loop, we need to handle that. loop_id: Optional[int] = id(running_loop)
curr_loop = asyncio.get_running_loop() except RuntimeError:
if curr_loop is _GLOBAL_LOOP: loop_id = None
# We are in the main loop. We cannot await/block here if we are just a sync function. logger.debug(
# But proxies are often called from sync code. "ISO:rpc_start proxy=%s method=%s instance_id=%s start_ts=%.6f "
# If called from sync code in main loop, creating a new loop is bad. "thread=%s loop=%s timeout_ms=%s",
# But we can't await `coro`. self.__class__.__name__,
# This implies proxies MUST be awaited if called from async context? method_name,
# Existing code used `run_coro_in_new_loop` which is weird. self._instance_id,
# Let's trust that if we are in a thread (RuntimeError on get_running_loop), start_epoch,
# we use run_coroutine_threadsafe. thread_id,
pass loop_id,
except RuntimeError: timeout_ms,
# No running loop - we are in a worker thread. )
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result()
try: try:
asyncio.get_running_loop() # If we have a global loop (Main Thread Loop), use it for dispatch from worker threads
return run_coro_in_new_loop(coro) if _GLOBAL_LOOP is not None and _GLOBAL_LOOP.is_running():
except RuntimeError: try:
loop = get_thread_loop() curr_loop = asyncio.get_running_loop()
return loop.run_until_complete(coro) if curr_loop is _GLOBAL_LOOP:
pass
except RuntimeError:
# No running loop - we are in a worker thread.
future = asyncio.run_coroutine_threadsafe(coro, _GLOBAL_LOOP)
return future.result(
timeout=(timeout_ms / 1000.0) if timeout_ms is not None else None
)
try:
asyncio.get_running_loop()
return run_coro_in_new_loop(coro)
except RuntimeError:
loop = get_thread_loop()
return loop.run_until_complete(coro)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
except concurrent.futures.TimeoutError as exc:
raise TimeoutError(
f"Isolation RPC timeout in {self.__class__.__name__}.{method_name} "
f"(instance_id={self._instance_id}, timeout_ms={timeout_ms})"
) from exc
finally:
end_epoch = time.time()
elapsed_ms = (time.perf_counter() - start_perf) * 1000.0
logger.debug(
"ISO:rpc_end proxy=%s method=%s instance_id=%s end_ts=%.6f "
"elapsed_ms=%.3f thread=%s loop=%s",
self.__class__.__name__,
method_name,
self._instance_id,
end_epoch,
elapsed_ms,
thread_id,
loop_id,
)
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
return {"_instance_id": self._instance_id} return {"_instance_id": self._instance_id}

View File

@ -192,7 +192,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if isolation_active: if isolation_active:
target_device = sigmas.device target_device = sigmas.device
if x.device != target_device: if x.device != target_device:

View File

@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
# unpatchify # unpatchify
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x return x
class SCAILWanModel(WanModel):
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
del scail_x
# time embeddings
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.cat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
if scail_pose_seq_len > 0:
x = x[:, :-scail_pose_seq_len]
# unpatchify
x = self.unpatchify(x, grid_sizes)
if reference_latent is not None:
x = x[:, :, reference_latent.shape[2]:]
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
h_scale = h / H_pose
w_scale = w / W_pose
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
reference_latent = None
if "reference_latent" in kwargs:
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]

View File

@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch import torch
import logging import logging
import os
import comfy.ldm.lightricks.av_model import comfy.ldm.lightricks.av_model
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_c import StageC
@ -76,6 +77,7 @@ class ModelType(Enum):
FLUX = 8 FLUX = 8
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -108,17 +110,23 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS: elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
from comfy.cli_args import args
isolation_runtime_enabled = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
class ModelSampling(s, c): class ModelSampling(s, c):
def __reduce__(self): if isolation_runtime_enabled:
"""Ensure pickling yields a proxy instead of failing on local class.""" def __reduce__(self):
try: """Ensure pickling yields a proxy instead of failing on local class."""
from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy try:
registry = ModelSamplingRegistry() from comfy.isolation.model_sampling_proxy import ModelSamplingRegistry, ModelSamplingProxy
ms_id = registry.register(self) registry = ModelSamplingRegistry()
return (ModelSamplingProxy, (ms_id,)) ms_id = registry.register(self)
except Exception as exc: return (ModelSamplingProxy, (ms_id,))
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc except Exception as exc:
raise RuntimeError("Failed to serialize ModelSampling for isolation.") from exc
return ModelSampling(model_config) return ModelSampling(model_config)
@ -998,6 +1006,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None: if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -1050,6 +1062,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None: if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1493,6 +1509,50 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class WAN21_FlowRVS(WAN21):
    """WAN 2.1 variant sampled with the IMG_TO_IMG_FLOW schedule."""

    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
        # Force the text-to-video transformer config regardless of what the
        # checkpoint metadata says.
        model_config.unet_config["model_type"] = "t2v"
        # super(WAN21, self) deliberately skips WAN21.__init__ and calls its
        # base class constructor directly with the plain WanModel.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
        self.image_to_video = image_to_video
class WAN21_SCAIL(WAN21):
    """WAN 2.1 SCAIL variant: adds pose-video and reference-latent conditioning."""

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # super(WAN21, self) deliberately bypasses WAN21.__init__ and calls its
        # base class constructor directly with the SCAIL-specific model.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
        # Conds that factor into the memory estimate, plus a shape adjustment
        # applied to pose latents when estimating (1.5x the time dimension —
        # presumably accounting for pose processing overhead; TODO confirm).
        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        """Extend base conds with reference/pose latents, each with a ones mask appended."""
        out = super().extra_conds(**kwargs)
        reference_latents = kwargs.get("reference_latents", None)
        if reference_latents is not None:
            # Only the most recent reference latent is used.
            ref_latent = self.process_latent_in(reference_latents[-1])
            # Mask has the same shape as the latent's first 4 channels.
            ref_mask = torch.ones_like(ref_latent[:, :4])
            ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
            out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
        pose_latents = kwargs.get("pose_video_latent", None)
        if pose_latents is not None:
            pose_latents = self.process_latent_in(pose_latents)
            pose_mask = torch.ones_like(pose_latents[:, :4])
            pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
            out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
        return out

    def extra_conds_shapes(self, **kwargs):
        """Report approximate shapes for the extra conds so memory use can be estimated."""
        out = {}
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            # Flattened element count across all reference latents, scaled by 1/16.
            out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
        pose_latents = kwargs.get("pose_video_latent", None)
        if pose_latents is not None:
            out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
        return out
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config return dit_config
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {} dit_config = {}
dit_config["image_model"] = "lumina2" dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2 dit_config["patch_size"] = 2
@ -498,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo" dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate" dit_config["model_type"] = "animate"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
else: else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v" dit_config["model_type"] = "i2v"
@ -531,8 +533,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1" dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
@ -1053,6 +1054,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
n_layers = count_blocks(state_dict, 'layers.{}.')
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
for k in state_dict: # For zeta chroma
if k not in sd_map:
sd_map[k] = k
elif 'x_embedder.weight' in state_dict: #Flux elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')

View File

@ -180,6 +180,14 @@ def is_ixuca():
return True return True
return False return False
def is_wsl():
    """Return True when running under Windows Subsystem for Linux.

    WSL1 kernels report a release ending in "-Microsoft"; WSL2 kernels end
    in "microsoft-standard-WSL2".
    """
    release = platform.uname().release
    return release.endswith(("-Microsoft", "microsoft-standard-WSL2"))
def get_torch_device(): def get_torch_device():
global directml_enabled global directml_enabled
global cpu_state global cpu_state
@ -475,6 +483,9 @@ except:
current_loaded_models = [] current_loaded_models = []
def _isolation_mode_enabled():
    # True in the parent process when process isolation is requested via CLI
    # args, and in isolated child workers, which are marked with the
    # PYISOLATE_CHILD environment variable.
    return args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
def module_size(module): def module_size(module):
module_mem = 0 module_mem = 0
sd = module.state_dict() sd = module.state_dict()
@ -554,8 +565,9 @@ class LoadedModel:
if freed >= memory_to_free: if freed >= memory_to_free:
return False return False
self.model.detach(unpatch_weights) self.model.detach(unpatch_weights)
self.model_finalizer.detach() if self.model_finalizer is not None:
self.model_finalizer = None self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None self.real_model = None
return True return True
@ -569,14 +581,15 @@ class LoadedModel:
if self._patcher_finalizer is not None: if self._patcher_finalizer is not None:
self._patcher_finalizer.detach() self._patcher_finalizer.detach()
def dead_state(self):
    """Return a pair of liveness flags (model_ref_gone, real_model_ref_gone).

    model_ref_gone: the patcher reference itself was released — per the old
    is_dead() comments this can happen with ModelPatcherProxy objects
    between isolated workflows.
    real_model_ref_gone: self.real_model is a weakref that now resolves to
    None, i.e. the underlying model was garbage collected.
    """
    model_ref_gone = self.model is None
    real_model_ref = self.real_model
    # real_model may be None (never loaded) or a weakref; only a dead
    # weakref counts as "gone" here.
    real_model_ref_gone = callable(real_model_ref) and real_model_ref() is None
    return model_ref_gone, real_model_ref_gone
def is_dead(self): def is_dead(self):
# Model is dead if the weakref to model has been garbage collected model_ref_gone, real_model_ref_gone = self.dead_state()
# This can happen with ModelPatcherProxy objects between isolated workflows return model_ref_gone or real_model_ref_gone
if self.model is None:
return True
if self.real_model is None:
return False
return self.real_model() is None
def use_more_memory(extra_memory, loaded_models, device): def use_more_memory(extra_memory, loaded_models, device):
@ -622,7 +635,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" isolation_active = _isolation_mode_enabled()
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
@ -649,12 +662,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram() ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models
#don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand.
#as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size()
memory_required -= current_loaded_models[i].model.loaded_size() memory_to_free = 0
memory_to_free = 0
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
@ -728,7 +740,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for i in to_unload: for i in to_unload:
model_to_unload = current_loaded_models.pop(i) model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False) model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach() if model_to_unload.model_finalizer is not None:
model_to_unload.model_finalizer.detach()
model_to_unload.model_finalizer = None
total_memory_required = {} total_memory_required = {}
@ -792,21 +806,55 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc(): def cleanup_models_gc():
reset_cast_buffers() reset_cast_buffers()
if not _isolation_mode_enabled():
dead_found = False
for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead():
dead_found = True
break
if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i]
if cur.is_dead():
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None)
if model_obj is not None:
cleanup = getattr(model_obj, "cleanup", None)
if callable(cleanup):
cleanup()
return
dead_found = False dead_found = False
has_real_model_leak = False
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if current_loaded_models[i].is_dead(): model_ref_gone, real_model_ref_gone = current_loaded_models[i].dead_state()
if model_ref_gone or real_model_ref_gone:
dead_found = True dead_found = True
break if real_model_ref_gone and not model_ref_gone:
has_real_model_leak = True
if dead_found: if dead_found:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.") if has_real_model_leak:
logging.info("Potential memory leak detected with model NoneType, doing a full garbage collect, for maximum performance avoid circular references in the model code.")
else:
logging.debug("Cleaning stale loaded-model entries with released patcher references.")
gc.collect() gc.collect()
soft_empty_cache() soft_empty_cache()
for i in range(len(current_loaded_models) - 1, -1, -1): for i in range(len(current_loaded_models) - 1, -1, -1):
cur = current_loaded_models[i] cur = current_loaded_models[i]
if cur.is_dead(): model_ref_gone, real_model_ref_gone = cur.dead_state()
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.") if model_ref_gone or real_model_ref_gone:
if real_model_ref_gone and not model_ref_gone:
logging.warning("WARNING, memory leak with model NoneType. Please make sure it is not being referenced from somewhere.")
else:
logging.debug("Cleaning stale loaded-model entry with released patcher reference.")
leaked = current_loaded_models.pop(i) leaked = current_loaded_models.pop(i)
model_obj = getattr(leaked, "model", None) model_obj = getattr(leaked, "model", None)
if model_obj is not None: if model_obj is not None:
@ -824,7 +872,11 @@ def archive_model_dtypes(model):
def cleanup_models(): def cleanup_models():
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None: real_model_ref = current_loaded_models[i].real_model
if real_model_ref is None:
to_delete = [i] + to_delete
continue
if callable(real_model_ref) and real_model_ref() is None:
to_delete = [i] + to_delete to_delete = [i] + to_delete
for i in to_delete: for i in to_delete:

View File

@ -308,15 +308,22 @@ class ModelPatcher:
def get_free_memory(self, device): def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device) return comfy.model_management.get_free_memory(device)
def clone(self, disable_dynamic=False): def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__ class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic: if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) if model_override is None:
model = temp_model_patcher.model if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -325,13 +332,12 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy() n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
# attachments # attachments
n.attachments = {} n.attachments = {}
for k in self.attachments: for k in self.attachments:
@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
del self.model.model_loaded_weight_memory del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"): if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {} self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
assert load_device is not None assert load_device is not None
def is_dynamic(self): def is_dynamic(self):
@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass pass
def get_non_dynamic_delegate(self):
    """Return a plain (non-dynamic) ModelPatcher clone of this dynamic patcher.

    The clone's model/backup state is cached on self.non_dynamic_delegate_model
    so repeated calls reuse the same underlying model rather than
    re-instantiating it each time.
    """
    model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
    self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
    return model_patcher
CoreModelPatcher = ModelPatcher CoreModelPatcher = ModelPatcher

View File

@ -66,6 +66,18 @@ def convert_cond(cond):
out.append(temp) out.append(temp)
return out return out
def cond_has_hooks(cond):
    """Return True if any conditioning entry carries hooks.

    An entry counts when its options dict has a "hooks" key, or when its
    "control" object reports one or more extra hooks.
    """
    for entry in cond:
        options = entry[1]
        if "hooks" in options:
            return True
        if "control" not in options:
            continue
        if len(options["control"].get_extra_hooks()) > 0:
            return True
    return False
def get_additional_models(conds, dtype): def get_additional_models(conds, dtype):
"""loads additional models in conditioning""" """loads additional models in conditioning"""
cnets: list[ControlBase] = [] cnets: list[ControlBase] = []

View File

@ -212,10 +212,11 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
_calc_cond_batch, _calc_cond_batch,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
) )
return executor.execute(model, conds, x_in, timestep, model_options) result = executor.execute(model, conds, x_in, timestep, model_options)
return result
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
out_conds = [] out_conds = []
out_counts = [] out_counts = []
# separate conds by matching hooks # separate conds by matching hooks
@ -272,7 +273,8 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
for k, v in to_run[tt][0].conditioning.items(): for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size()) cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: memory_required = model.memory_required(input_shape, cond_shapes=cond_shapes)
if memory_required * 1.5 < free_memory:
to_batch = batch_amount to_batch = batch_amount
break break
@ -411,7 +413,7 @@ class KSamplerX0Inpaint:
self.inner_model = model self.inner_model = model
self.sigmas = sigmas self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1" isolation_active = args.use_process_isolation or os.environ.get("PYISOLATE_CHILD") == "1"
if denoise_mask is not None: if denoise_mask is not None:
if isolation_active and denoise_mask.device != x.device: if isolation_active and denoise_mask.device != x.device:
denoise_mask = denoise_mask.to(x.device) denoise_mask = denoise_mask.to(x.device)
@ -777,7 +779,11 @@ class KSAMPLER(Sampler):
else: else:
model_k.noise = noise model_k.noise = noise
noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas)) max_denoise = self.max_denoise(model_wrap, sigmas)
model_sampling = model_wrap.inner_model.model_sampling
noise = model_sampling.noise_scaling(
sigmas[0], noise, latent_image, max_denoise
)
k_callback = None k_callback = None
total_steps = len(sigmas) - 1 total_steps = len(sigmas) - 1
@ -982,6 +988,8 @@ class CFGGuider:
def inner_set_conds(self, conds): def inner_set_conds(self, conds):
for k in conds: for k in conds:
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):

View File

@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
class CLIP: class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
if no_init: if no_init:
return return
params = target.params.copy() params = target.params.copy()
@ -233,7 +233,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model) model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention #Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32) self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@ -267,9 +268,9 @@ class CLIP:
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {} self.tokenizer_options = {}
def clone(self): def clone(self, disable_dynamic=False):
n = CLIP(no_init=True) n = CLIP(no_init=True)
n.patcher = self.patcher.clone() n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
n.cond_stage_model = self.cond_stage_model n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx n.layer_idx = self.layer_idx
@ -1164,14 +1165,21 @@ class CLIPType(Enum):
LONGCAT_IMAGE = 26 LONGCAT_IMAGE = 26
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
    """Load a CLIP from checkpoint paths and return only its ModelPatcher.

    Thin convenience wrapper around load_clip(); used as a cached_patcher_init
    callable so the patcher can be re-created on demand.
    """
    loaded = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
    return loaded.patcher
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None: if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd) clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
return clip
class TEModel(Enum): class TEModel(Enum):
@ -1276,7 +1284,7 @@ def llama_detect(clip_data):
return {} return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = state_dicts clip_data = state_dicts
class EmptyClass: class EmptyClass:
@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
return clip return clip
def load_gligen(ckpt_path): def load_gligen(ckpt_path):
@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model: if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
disable_dynamic=disable_dynamic) disable_dynamic=disable_dynamic)
return model return model
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Load only the CLIP/text-encoder from a checkpoint and return its patcher.

    Requests just the CLIP output (no VAE, no clipvision, no diffusion model)
    from load_checkpoint_guess_config; used as a cached_patcher_init callable.
    """
    out = load_checkpoint_guess_config(ckpt_path, False, True, False,
                                       embedding_directory=embedding_directory,
                                       output_model=False,
                                       model_options=model_options,
                                       te_model_options=te_model_options,
                                       disable_dynamic=disable_dynamic)
    clip = out[1]
    return clip.patcher
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None clip = None
clipvision = None clipvision = None
@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd) clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0: if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd) parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
else: else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

View File

@ -1268,6 +1268,16 @@ class WAN21_FlowRVS(WAN21_T2V):
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
return out return out
class WAN21_SCAIL(WAN21_T2V):
    # Model-detection entry for the WAN 2.1 SCAIL variant: matched against a
    # state dict via the unet_config keys below (same detection scheme as the
    # sibling WAN21_* classes in this file).
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "scail",
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate the SCAIL model wrapper for this config.

        Unlike some WAN I2V variants, this is constructed with
        image_to_video=False.
        """
        out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
        return out
class Hunyuan3Dv2(supported_models_base.BASE): class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "hunyuan3d2", "image_model": "hunyuan3d2",
@ -1710,6 +1720,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model = "gemini-3.1-flash-image-preview" model = "gemini-3.1-flash-image-preview"
if response_modalities == "IMAGE+TEXT":
raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.")
parts: list[GeminiPart] = [GeminiPart(text=prompt)] parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images is not None: if images is not None:
@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"response_modalities", "response_modalities",
options=["IMAGE"], options=["IMAGE", "IMAGE+TEXT"],
advanced=True, advanced=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
IO.String.Output(),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,

View File

@ -20,7 +20,7 @@ class JobStatus:
# Media types that can be previewed in the frontend # Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
# 3D file extensions for preview fallback (no dedicated media_type exists) # 3D file extensions for preview fallback (no dedicated media_type exists)
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
normalized[node_id] = normalized_node normalized[node_id] = normalized_node
return normalized return normalized
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
TEXT_PREVIEW_MAX_LENGTH = 1024
def _create_text_preview(value: str) -> dict:
"""Create a text preview dict with optional truncation.
Returns:
dict with 'content' and optionally 'truncated' flag
"""
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
return {'content': value}
return {
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
'truncated': True
}
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data. """Extract create_time and workflow_id from extra_data.
@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue continue
for item in items: for item in items:
normalized = normalize_output_item(item) if not isinstance(item, dict):
if normalized is None: # Handle text outputs (non-dict items like strings or tuples)
continue normalized = normalize_output_item(item)
if normalized is None:
# Not a 3D file string — check for text preview
if media_type == 'text':
count += 1
if preview_output is None:
if isinstance(item, tuple):
text_value = item[0] if item else ''
else:
text_value = str(item)
text_preview = _create_text_preview(text_value)
enriched = {
**text_preview,
'nodeId': node_id,
'mediaType': media_type
}
if fallback_preview is None:
fallback_preview = enriched
continue
# normalize_output_item returned a dict (e.g. 3D file)
item = normalized
count += 1 count += 1
if preview_output is not None: if preview_output is not None:
continue continue
if isinstance(normalized, dict) and is_previewable(media_type, normalized): if is_previewable(media_type, item):
enriched = { enriched = {
**normalized, **item,
'nodeId': node_id, 'nodeId': node_id,
} }
if 'mediaType' not in normalized: if 'mediaType' not in item:
enriched['mediaType'] = media_type enriched['mediaType'] = media_type
if normalized.get('type') == 'output': if item.get('type') == 'output':
preview_output = enriched preview_output = enriched
elif fallback_preview is None: elif fallback_preview is None:
fallback_preview = enriched fallback_preview = enriched

View File

@ -248,7 +248,7 @@ class SetClipHooks:
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None: if hooks is not None:
clip = clip.clone() clip = clip.clone(disable_dynamic=True)
if apply_to_conds: if apply_to_conds:
clip.apply_hooks_to_conds = hooks clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone() clip.patcher.forced_hooks = hooks.clone()

View File

@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="Mahiro", node_id="Mahiro",
display_name="Mahiro CFG", display_name="Positive-Biased Guidance",
category="_for_testing", category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[ inputs=[
@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
io.Model.Output(display_name="patched_model"), io.Model.Output(display_name="patched_model"),
], ],
is_experimental=True, is_experimental=True,
search_aliases=[
"mahiro",
"mahiro cfg",
"similarity-adaptive guidance",
"positive-biased cfg",
],
) )
@classmethod @classmethod
def execute(cls, model) -> io.NodeOutput: def execute(cls, model) -> io.NodeOutput:
m = model.clone() m = model.clone()
def mahiro_normd(args): def mahiro_normd(args):
scale: float = args['cond_scale'] scale: float = args["cond_scale"]
cond_p: torch.Tensor = args['cond_denoised'] cond_p: torch.Tensor = args["cond_denoised"]
uncond_p: torch.Tensor = args['uncond_denoised'] uncond_p: torch.Tensor = args["uncond_denoised"]
#naive leap # naive leap
leap = cond_p * scale leap = cond_p * scale
#sim with uncond leap # sim with uncond leap
u_leap = uncond_p * scale u_leap = uncond_p * scale
cfg = args["denoised"] cfg = args["denoised"]
merge = (leap + cfg) / 2 merge = (leap + cfg) / 2
normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
normm = torch.sqrt(merge.abs()) * merge.sign() normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean() sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1) simsc = 2 * (sim + 1)
wm = (simsc*cfg + (4-simsc)*leap) / 4 wm = (simsc * cfg + (4 - simsc) * leap) / 4
return wm return wm
m.set_model_sampler_post_cfg_function(mahiro_normd) m.set_model_sampler_post_cfg_function(mahiro_normd)
return io.NodeOutput(m) return io.NodeOutput(m)

View File

@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanSCAILToVideo(io.ComfyNode):
    """Conditioning node for WAN SCAIL video generation.

    Builds an empty latent of the requested size and augments the
    positive/negative conditioning with an optional reference-image latent,
    CLIP-vision output, and a pose-video latent that is only active within a
    [pose_start, pose_end] timestep window.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanSCAILToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("reference_image", optional=True),
                io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
                io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
                io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
                io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
        # Empty 16-channel latent; temporal dim is ((length - 1) // 4) + 1 and
        # spatial dims are 1/8 of the pixel size (VAE downscale factors used here).
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        ref_latent = None
        if reference_image is not None:
            # Use only the first frame, resized to the generation resolution.
            reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            # Drop any alpha channel before encoding.
            ref_latent = vae.encode(reference_image[:, :, :, :3])
        if ref_latent is not None:
            # Positive gets the real reference latent; negative gets zeros of the
            # same shape so both branches see a consistent conditioning structure.
            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
        if pose_video is not None:
            # Pose video is trimmed to `length` frames and downscaled to half the
            # main resolution (matches the input tooltip above).
            pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
            pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
            # Pose conditioning is only active inside [pose_start, pose_end].
            positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
            negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(positive, negative, out_latent)
class WanExtension(ComfyExtension): class WanExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
WanAnimateToVideo, WanAnimateToVideo,
Wan22ImageToVideoLatent, Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo, WanInfiniteTalkToVideo,
WanSCAILToVideo,
] ]
async def comfy_entrypoint() -> WanExtension: async def comfy_entrypoint() -> WanExtension:

View File

@ -43,6 +43,8 @@ from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io from comfy_api.latest import io, _io
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = False
class ExecutionResult(Enum): class ExecutionResult(Enum):
SUCCESS = 0 SUCCESS = 0
@ -540,7 +542,17 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG": if args.verbose == "DEBUG":
comfy_aimdo.control.analyze() comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers() comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits() vbar_lib = getattr(comfy_aimdo.model_vbar, "lib", None)
if vbar_lib is not None:
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
else:
global _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED
if not _AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED:
logging.warning(
"DynamicVRAM backend unavailable for watermark reset; "
"skipping vbar reset for this process."
)
_AIMDO_VBAR_RESET_UNAVAILABLE_LOGGED = True
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data
@ -669,6 +681,8 @@ class PromptExecutor:
self.success = True self.success = True
async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None: async def _notify_execution_graph_safe(self, class_types: set[str], *, fail_loud: bool = False) -> None:
if not args.use_process_isolation:
return
try: try:
from comfy.isolation import notify_execution_graph from comfy.isolation import notify_execution_graph
await notify_execution_graph(class_types) await notify_execution_graph(class_types)
@ -678,12 +692,34 @@ class PromptExecutor:
logging.debug("][ EX:notify_execution_graph failed", exc_info=True) logging.debug("][ EX:notify_execution_graph failed", exc_info=True)
async def _flush_running_extensions_transport_state_safe(self) -> None: async def _flush_running_extensions_transport_state_safe(self) -> None:
if not args.use_process_isolation:
return
try: try:
from comfy.isolation import flush_running_extensions_transport_state from comfy.isolation import flush_running_extensions_transport_state
await flush_running_extensions_transport_state() await flush_running_extensions_transport_state()
except Exception: except Exception:
logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True) logging.debug("][ EX:flush_running_extensions_transport_state failed", exc_info=True)
async def _wait_model_patcher_quiescence_safe(
    self,
    *,
    fail_loud: bool = False,
    timeout_ms: int = 120000,
    marker: str = "EX:wait_model_patcher_idle",
) -> None:
    """Wait for in-flight isolated model-patcher RPC operations to drain.

    No-op unless process isolation is enabled. The import is deferred so
    non-isolated deployments never load comfy.isolation.

    Args:
        fail_loud: when True, re-raise any failure (including timeout);
            when False, swallow it and log at debug level.
        timeout_ms: forwarded to wait_for_model_patcher_quiescence.
        marker: log/trace marker forwarded to the wait call.
    """
    if not args.use_process_isolation:
        return
    try:
        from comfy.isolation import wait_for_model_patcher_quiescence
        await wait_for_model_patcher_quiescence(
            timeout_ms=timeout_ms, fail_loud=fail_loud, marker=marker
        )
    except Exception:
        # Best-effort by default: a quiescence failure should not abort the
        # workflow unless the caller explicitly asked for fail_loud.
        if fail_loud:
            raise
        logging.debug("][ EX:wait_model_patcher_quiescence failed", exc_info=True)
def add_message(self, event, data: dict, broadcast: bool): def add_message(self, event, data: dict, broadcast: bool):
data = { data = {
**data, **data,
@ -725,16 +761,17 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# Update RPC event loops for all isolated extensions if args.use_process_isolation:
# This is critical for serial workflow execution - each asyncio.run() creates # Update RPC event loops for all isolated extensions.
# a new event loop, and RPC instances must be updated to use it # This is critical for serial workflow execution - each asyncio.run() creates
try: # a new event loop, and RPC instances must be updated to use it.
from comfy.isolation import update_rpc_event_loops try:
update_rpc_event_loops() from comfy.isolation import update_rpc_event_loops
except ImportError: update_rpc_event_loops()
pass # Isolation not available except ImportError:
except Exception as e: pass # Isolation not available
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}") except Exception as e:
logging.getLogger(__name__).warning(f"Failed to update RPC event loops: {e}")
set_preview_method(extra_data.get("preview_method")) set_preview_method(extra_data.get("preview_method"))
@ -754,6 +791,11 @@ class PromptExecutor:
# Boundary cleanup runs at the start of the next workflow in # Boundary cleanup runs at the start of the next workflow in
# isolation mode, matching non-isolated "next prompt" timing. # isolation mode, matching non-isolated "next prompt" timing.
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
await self._wait_model_patcher_quiescence_safe(
fail_loud=False,
timeout_ms=120000,
marker="EX:boundary_cleanup_wait_idle",
)
await self._flush_running_extensions_transport_state_safe() await self._flush_running_extensions_transport_state_safe()
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
comfy.model_management.cleanup_models_gc() comfy.model_management.cleanup_models_gc()
@ -794,6 +836,11 @@ class PromptExecutor:
for node_id in execution_list.pendingNodes.keys(): for node_id in execution_list.pendingNodes.keys():
class_type = dynamic_prompt.get_node(node_id)["class_type"] class_type = dynamic_prompt.get_node(node_id)["class_type"]
pending_class_types.add(class_type) pending_class_types.add(class_type)
await self._wait_model_patcher_quiescence_safe(
fail_loud=True,
timeout_ms=120000,
marker="EX:notify_graph_wait_idle",
)
await self._notify_execution_graph_safe(pending_class_types, fail_loud=True) await self._notify_execution_graph_safe(pending_class_types, fail_loud=True)
while not execution_list.is_empty(): while not execution_list.is_empty():

16
main.py
View File

@ -25,6 +25,15 @@ from app.assets.scanner import seed_assets
import itertools import itertools
import logging import logging
import comfy_aimdo.control
if enables_dynamic_vram():
if not comfy_aimdo.control.init():
logging.warning(
"DynamicVRAM requested, but comfy-aimdo failed to initialize early. "
"Will fall back to legacy model loading if device init fails."
)
if '--use-process-isolation' in sys.argv: if '--use-process-isolation' in sys.argv:
from comfy.isolation import initialize_proxies from comfy.isolation import initialize_proxies
initialize_proxies() initialize_proxies()
@ -208,11 +217,6 @@ import gc
if 'torch' in sys.modules: if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils import comfy.utils
if not IS_PYISOLATE_CHILD: if not IS_PYISOLATE_CHILD:
@ -228,7 +232,7 @@ if not IS_PYISOLATE_CHILD:
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
if enables_dynamic_vram(): if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
if comfy.model_management.torch_version_numeric < (2, 8): if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):

View File

@ -1,5 +1,6 @@
import hashlib import hashlib
import torch import torch
import logging
from comfy.cli_args import args from comfy.cli_args import args
@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
return c return c
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
    """
    Apply values to conditioning only during [start_percent, end_percent], keeping the
    original conditioning active outside that range. Respects existing per-entry ranges.
    """
    if start_percent > end_percent:
        logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
        return conditioning
    # The sampler gates entries with strict > / < comparisons, so shift the
    # split boundaries slightly to ensure only one piece is active per timestep.
    EPS = 1e-5
    result = []
    for entry in conditioning:
        entry_start = entry[1].get("start_percent", 0.0)
        entry_end = entry[1].get("end_percent", 1.0)
        overlap_start = max(start_percent, entry_start)
        overlap_end = min(end_percent, entry_end)
        # No overlap with the requested window: keep the entry untouched.
        if overlap_start >= overlap_end:
            result.append(entry)
            continue
        # Leading piece before the window, if any.
        if entry_start < overlap_start:
            result.extend(conditioning_set_values([entry], {"start_percent": entry_start, "end_percent": overlap_start - EPS}))
        # The window itself, with the requested values applied.
        result.extend(conditioning_set_values([entry], {**values, "start_percent": overlap_start, "end_percent": overlap_end}))
        # Trailing piece after the window, if any.
        if overlap_end < entry_end:
            result.extend(conditioning_set_values([entry], {"start_percent": overlap_end + EPS, "end_percent": entry_end}))
    return result
def pillow(fn, arg): def pillow(fn, arg):
prev_value = None prev_value = None
try: try:

View File

@ -10,6 +10,25 @@ homepage = "https://www.comfy.org/"
repository = "https://github.com/comfyanonymous/ComfyUI" repository = "https://github.com/comfyanonymous/ComfyUI"
documentation = "https://docs.comfy.org/" documentation = "https://docs.comfy.org/"
[tool.comfy.host]
allow_network = false
writable_paths = ["/dev/shm", "/tmp"]
[tool.comfy.host.whitelist]
"ComfyUI-Crystools" = "*"
"ComfyUI-Florence2" = "*"
"ComfyUI-GGUF" = "*"
"ComfyUI-KJNodes" = "*"
"ComfyUI-LTXVideo" = "*"
"ComfyUI-Manager" = "*"
"comfyui-depthanythingv2" = "*"
"comfyui-kjnodes" = "*"
"comfyui-videohelpersuite" = "*"
"comfyui_controlnet_aux" = "*"
"rgthree-comfy" = "*"
"was-ns" = "*"
"websocket_image_save.py" = "*"
[tool.ruff] [tool.ruff]
lint.select = [ lint.select = [
"N805", # invalid-first-argument-name-for-method "N805", # invalid-first-argument-name-for-method

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19 comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.4 comfyui-workflow-templates==0.9.5
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@ -22,7 +22,7 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.2 comfy-aimdo>=0.2.4
requests requests
#non essential dependencies: #non essential dependencies:

View File

@ -49,6 +49,12 @@ def mock_provider(mock_releases):
return provider return provider
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the module-level requirements cache before every test.

    get_required_packages_versions memoizes into PACKAGE_VERSIONS, so tests
    would otherwise observe data cached by an earlier test.
    """
    import utils.install_util
    utils.install_util.PACKAGE_VERSIONS = {}
def test_get_release(mock_provider, mock_releases): def test_get_release(mock_provider, mock_releases):
version = "1.0.0" version = "1.0.0"
release = mock_provider.get_release(version) release = mock_provider.get_release(version)

View File

@ -38,13 +38,13 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()""" """Unit tests for is_previewable()"""
def test_previewable_media_types(self): def test_previewable_media_types(self):
"""Images, video, audio, 3d media types should be previewable.""" """Images, video, audio, 3d, text media types should be previewable."""
for media_type in ['images', 'video', 'audio', '3d']: for media_type in ['images', 'video', 'audio', '3d', 'text']:
assert is_previewable(media_type, {}) is True assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self): def test_non_previewable_media_types(self):
"""Other media types should not be previewable.""" """Other media types should not be previewable."""
for media_type in ['latents', 'text', 'metadata', 'files']: for media_type in ['latents', 'metadata', 'files']:
assert is_previewable(media_type, {}) is False assert is_previewable(media_type, {}) is False
def test_3d_extensions_previewable(self): def test_3d_extensions_previewable(self):

View File

@ -1,5 +1,7 @@
from pathlib import Path from pathlib import Path
import sys import sys
import logging
import re
# The path to the requirements.txt file # The path to the requirements.txt file
requirements_path = Path(__file__).parents[1] / "requirements.txt" requirements_path = Path(__file__).parents[1] / "requirements.txt"
@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {requirements_path} {sys.executable} {extra}-m pip install -r {requirements_path}
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
""".strip() """.strip()
def is_valid_version(version: str) -> bool:
    """Validate if a string is a valid semantic version (X.Y.Z format).

    Uses re.fullmatch instead of an anchored re.match: with ``$``, a trailing
    newline ("1.2.3\\n") would incorrectly be accepted, since ``$`` matches
    just before a final newline.
    """
    pattern = r"(\d+)\.(\d+)\.(\d+)"
    return re.fullmatch(pattern, version) is not None
# Cache of parsed requirement pins ({package: version}); populated lazily on
# the first successful read of requirements.txt.
PACKAGE_VERSIONS = {}

def get_required_packages_versions():
    """Parse requirements.txt into a {package: version} mapping.

    Lines pinned with "==" or ">=" are treated as exact pins; versions that
    are not strict X.Y.Z are logged and skipped.

    Returns:
        A copy of the mapping, or None when requirements.txt is missing or
        unreadable.
    """
    if len(PACKAGE_VERSIONS) > 0:
        return PACKAGE_VERSIONS.copy()
    # Parse into a local dict first and commit atomically at the end: the
    # previous implementation filled the module-level cache while reading, so
    # an exception mid-file left a partial cache that later calls returned as
    # if it were complete.
    parsed = {}
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                # Normalize ">=" pins so both pin styles split the same way.
                line = raw_line.strip().replace(">=", "==")
                parts = line.split("==")
                if len(parts) == 2:
                    name, version_str = parts
                    if not is_valid_version(version_str):
                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
                        continue
                    parsed[name] = version_str
    except FileNotFoundError:
        logging.error("requirements.txt not found.")
        return None
    except Exception as e:
        logging.error(f"Error reading requirements.txt: {e}")
        return None
    PACKAGE_VERSIONS.update(parsed)
    return PACKAGE_VERSIONS.copy()