mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 21:12:30 +08:00
Merge pull request #13380 from pollockjj/issue61-clean-pyisolate-support
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
pyisolate-support release refresh
This commit is contained in:
commit
ea092cd1e7
@ -44,21 +44,15 @@ def initialize_proxies() -> None:
|
||||
from .child_hooks import is_child_process
|
||||
|
||||
is_child = is_child_process()
|
||||
logger.warning(
|
||||
"%s DIAG:initialize_proxies | is_child=%s | PYISOLATE_CHILD=%s",
|
||||
LOG_PREFIX, is_child, os.environ.get("PYISOLATE_CHILD"),
|
||||
)
|
||||
|
||||
if is_child:
|
||||
from .child_hooks import initialize_child_process
|
||||
|
||||
initialize_child_process()
|
||||
logger.warning("%s DIAG:initialize_proxies child_process initialized", LOG_PREFIX)
|
||||
else:
|
||||
from .host_hooks import initialize_host_process
|
||||
|
||||
initialize_host_process()
|
||||
logger.warning("%s DIAG:initialize_proxies host_process initialized", LOG_PREFIX)
|
||||
if start_shm_forensics is not None:
|
||||
start_shm_forensics()
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from pyisolate.interfaces import IsolationAdapter, SerializerRegistryProtocol # type: ignore[import-untyped]
|
||||
from pyisolate._internal.rpc_protocol import AsyncRPC, ProxiedSingleton # type: ignore[import-untyped]
|
||||
@ -586,29 +586,6 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
|
||||
register_hooks_serializers(registry)
|
||||
|
||||
# Generic Numpy Serializer
|
||||
def serialize_numpy(obj: Any) -> Any:
|
||||
import torch
|
||||
|
||||
try:
|
||||
# Attempt zero-copy conversion to Tensor
|
||||
return torch.from_numpy(obj)
|
||||
except Exception:
|
||||
# Fallback for non-numeric arrays (strings, objects, mixes)
|
||||
return obj.tolist()
|
||||
|
||||
def deserialize_numpy_b64(data: Any) -> Any:
|
||||
"""Deserialize base64-encoded ndarray from sealed worker."""
|
||||
import base64
|
||||
import numpy as np
|
||||
if isinstance(data, dict) and "data" in data and "dtype" in data:
|
||||
raw = base64.b64decode(data["data"])
|
||||
arr = np.frombuffer(raw, dtype=np.dtype(data["dtype"])).reshape(data["shape"])
|
||||
return torch.from_numpy(arr.copy())
|
||||
return data
|
||||
|
||||
registry.register("ndarray", serialize_numpy, deserialize_numpy_b64)
|
||||
|
||||
# -- File3D (comfy_api.latest._util.geometry_types) ---------------------
|
||||
# Origin: comfy_api by ComfyOrg (Alexander Piskun), PR #12129
|
||||
|
||||
@ -873,93 +850,15 @@ class ComfyUIAdapter(IsolationAdapter):
|
||||
|
||||
return
|
||||
|
||||
if api_name == "PromptServerProxy":
|
||||
if api_name == "PromptServerService":
|
||||
if not _IMPORT_TORCH:
|
||||
return
|
||||
# Defer heavy import to child context
|
||||
import server
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
instance = api() if isinstance(api, type) else api
|
||||
proxy = (
|
||||
instance.instance
|
||||
) # PromptServerProxy instance has .instance property returning self
|
||||
|
||||
original_register_route = proxy.register_route
|
||||
|
||||
def register_route_wrapper(
|
||||
method: str, path: str, handler: Callable[..., Any]
|
||||
) -> None:
|
||||
callback_id = rpc.register_callback(handler)
|
||||
loop = getattr(rpc, "loop", None)
|
||||
if loop and loop.is_running():
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
original_register_route(
|
||||
method, path, handler=callback_id, is_callback=True
|
||||
)
|
||||
return None
|
||||
|
||||
proxy.register_route = register_route_wrapper
|
||||
|
||||
class RouteTableDefProxy:
|
||||
def __init__(self, proxy_instance: Any):
|
||||
self.proxy = proxy_instance
|
||||
|
||||
def get(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("GET", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def post(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("POST", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def patch(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PATCH", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def put(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("PUT", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
def delete(
|
||||
self, path: str, **kwargs: Any
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
def decorator(handler: Callable[..., Any]) -> Callable[..., Any]:
|
||||
self.proxy.register_route("DELETE", path, handler)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
proxy.routes = RouteTableDefProxy(proxy)
|
||||
|
||||
stub = PromptServerStub()
|
||||
if (
|
||||
hasattr(server, "PromptServer")
|
||||
and getattr(server.PromptServer, "instance", None) != proxy
|
||||
and getattr(server.PromptServer, "instance", None) is not stub
|
||||
):
|
||||
server.PromptServer.instance = proxy
|
||||
server.PromptServer.instance = stub
|
||||
|
||||
@ -31,7 +31,6 @@ def _load_extra_model_paths() -> None:
|
||||
|
||||
|
||||
def initialize_child_process() -> None:
|
||||
logger.warning("][ DIAG:child_hooks initialize_child_process START")
|
||||
if os.environ.get("PYISOLATE_IMPORT_TORCH", "1") != "0":
|
||||
_load_extra_model_paths()
|
||||
_setup_child_loop_bridge()
|
||||
@ -41,15 +40,12 @@ def initialize_child_process() -> None:
|
||||
from pyisolate._internal.rpc_protocol import get_child_rpc_instance
|
||||
|
||||
rpc = get_child_rpc_instance()
|
||||
logger.warning("][ DIAG:child_hooks RPC instance: %s", rpc is not None)
|
||||
if rpc:
|
||||
_setup_proxy_callers(rpc)
|
||||
logger.warning("][ DIAG:child_hooks proxy callers configured with RPC")
|
||||
else:
|
||||
logger.warning("][ DIAG:child_hooks NO RPC — proxy callers cleared")
|
||||
_setup_proxy_callers()
|
||||
except Exception as e:
|
||||
logger.error(f"][ DIAG:child_hooks Manual RPC Injection failed: {e}")
|
||||
logger.error(f"][ child_hooks Manual RPC Injection failed: {e}")
|
||||
_setup_proxy_callers()
|
||||
|
||||
_setup_logging()
|
||||
|
||||
@ -354,6 +354,16 @@ async def load_isolated_node(
|
||||
"sandbox": sandbox_config,
|
||||
}
|
||||
|
||||
share_torch_no_deps = tool_config.get("share_torch_no_deps", [])
|
||||
if share_torch_no_deps:
|
||||
if not isinstance(share_torch_no_deps, list) or not all(
|
||||
isinstance(dep, str) and dep.strip() for dep in share_torch_no_deps
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.share_torch_no_deps] must be a list of non-empty strings"
|
||||
)
|
||||
extension_config["share_torch_no_deps"] = share_torch_no_deps
|
||||
|
||||
_is_sealed = execution_model == "sealed_worker"
|
||||
_is_sandboxed = host_policy["sandbox_mode"] != "disabled" and is_linux
|
||||
logger.info(
|
||||
@ -367,6 +377,16 @@ async def load_isolated_node(
|
||||
if cuda_wheels is not None:
|
||||
extension_config["cuda_wheels"] = cuda_wheels
|
||||
|
||||
extra_index_urls = tool_config.get("extra_index_urls", [])
|
||||
if extra_index_urls:
|
||||
if not isinstance(extra_index_urls, list) or not all(
|
||||
isinstance(u, str) and u.strip() for u in extra_index_urls
|
||||
):
|
||||
raise ExtensionLoadError(
|
||||
"[tool.comfy.isolation.extra_index_urls] must be a list of non-empty strings"
|
||||
)
|
||||
extension_config["extra_index_urls"] = extra_index_urls
|
||||
|
||||
# Conda-specific keys
|
||||
if is_conda:
|
||||
extension_config["package_manager"] = "conda"
|
||||
@ -408,31 +428,17 @@ async def load_isolated_node(
|
||||
cache.register_proxy(extension_name, WebDirectoryProxy())
|
||||
|
||||
# Try cache first (lazy spawn)
|
||||
logger.warning("][ DIAG:ext_loader cache_valid_check for %s", extension_name)
|
||||
if is_cache_valid(node_dir, manifest_path, venv_root):
|
||||
cached_data = load_from_cache(node_dir, venv_root)
|
||||
if cached_data:
|
||||
if _is_stale_node_cache(cached_data):
|
||||
logger.warning(
|
||||
"][ DIAG:ext_loader %s cache is stale/incompatible; rebuilding metadata",
|
||||
extension_name,
|
||||
)
|
||||
pass
|
||||
else:
|
||||
logger.warning("][ DIAG:ext_loader %s USING CACHE — dumping combo options:", extension_name)
|
||||
for node_name, details in cached_data.items():
|
||||
schema_v1 = details.get("schema_v1", {})
|
||||
inp = schema_v1.get("input", {}) if schema_v1 else {}
|
||||
for section_name, section in inp.items():
|
||||
if isinstance(section, dict):
|
||||
for field_name, field_def in section.items():
|
||||
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
|
||||
opts = field_def[1]["options"]
|
||||
logger.warning(
|
||||
"][ DIAG:ext_loader CACHE %s.%s.%s options=%d first=%s",
|
||||
node_name, section_name, field_name,
|
||||
len(opts),
|
||||
opts[:3] if opts else "EMPTY",
|
||||
)
|
||||
try:
|
||||
flushed = await extension.flush_pending_routes()
|
||||
logger.info("][ %s flushed %d routes", extension_name, flushed)
|
||||
except Exception as exc:
|
||||
logger.warning("][ %s route flush failed: %s", extension_name, exc)
|
||||
specs: List[Tuple[str, str, type]] = []
|
||||
for node_name, details in cached_data.items():
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
@ -440,11 +446,7 @@ async def load_isolated_node(
|
||||
(node_name, details.get("display_name", node_name), stub_cls)
|
||||
)
|
||||
return specs
|
||||
else:
|
||||
logger.warning("][ DIAG:ext_loader %s cache INVALID or MISSING", extension_name)
|
||||
|
||||
# Cache miss - spawn process and get metadata
|
||||
logger.warning("][ DIAG:ext_loader %s cache miss, spawning process for metadata", extension_name)
|
||||
|
||||
try:
|
||||
remote_nodes: Dict[str, str] = await extension.list_nodes()
|
||||
@ -466,7 +468,6 @@ async def load_isolated_node(
|
||||
cache_data: Dict[str, Dict] = {}
|
||||
|
||||
for node_name, display_name in remote_nodes.items():
|
||||
logger.warning("][ DIAG:ext_loader calling get_node_details for %s.%s", extension_name, node_name)
|
||||
try:
|
||||
details = await extension.get_node_details(node_name)
|
||||
except Exception as exc:
|
||||
@ -477,20 +478,6 @@ async def load_isolated_node(
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
# DIAG: dump combo options from freshly-fetched details
|
||||
schema_v1 = details.get("schema_v1", {})
|
||||
inp = schema_v1.get("input", {}) if schema_v1 else {}
|
||||
for section_name, section in inp.items():
|
||||
if isinstance(section, dict):
|
||||
for field_name, field_def in section.items():
|
||||
if isinstance(field_def, (list, tuple)) and len(field_def) >= 2 and isinstance(field_def[1], dict) and "options" in field_def[1]:
|
||||
opts = field_def[1]["options"]
|
||||
logger.warning(
|
||||
"][ DIAG:ext_loader FRESH %s.%s.%s options=%d first=%s",
|
||||
node_name, section_name, field_name,
|
||||
len(opts),
|
||||
opts[:3] if opts else "EMPTY",
|
||||
)
|
||||
details["display_name"] = display_name
|
||||
cache_data[node_name] = details
|
||||
stub_cls = build_stub_class(node_name, details, extension)
|
||||
@ -512,6 +499,14 @@ async def load_isolated_node(
|
||||
if host_policy["sandbox_mode"] == "disabled":
|
||||
_register_web_directory(extension_name, node_dir)
|
||||
|
||||
# Flush any routes the child buffered during module import — must happen
|
||||
# before router freeze and before we kill the child process.
|
||||
try:
|
||||
flushed = await extension.flush_pending_routes()
|
||||
logger.info("][ %s flushed %d routes", extension_name, flushed)
|
||||
except Exception as exc:
|
||||
logger.warning("][ %s route flush failed: %s", extension_name, exc)
|
||||
|
||||
# EJECT: Kill process after getting metadata (will respawn on first execution)
|
||||
await _stop_extension_safe(extension, extension_name)
|
||||
|
||||
|
||||
@ -211,6 +211,7 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
|
||||
self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {}
|
||||
self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {}
|
||||
self._register_module_routes(module)
|
||||
|
||||
# Register web directory with WebDirectoryProxy (child-side)
|
||||
web_dir_attr = getattr(module, "WEB_DIRECTORY", None)
|
||||
@ -280,6 +281,55 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
|
||||
self.node_instances = {}
|
||||
|
||||
def _register_module_routes(self, module: Any) -> None:
|
||||
"""Bridge legacy module-level ROUTES declarations into isolated routing."""
|
||||
routes = getattr(module, "ROUTES", None) or []
|
||||
if not routes:
|
||||
return
|
||||
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
prompt_server = PromptServerStub()
|
||||
route_table = getattr(prompt_server, "routes", None)
|
||||
if route_table is None:
|
||||
logger.warning("%s Route registration unavailable for %s", LOG_PREFIX, module)
|
||||
return
|
||||
|
||||
for route_spec in routes:
|
||||
if not isinstance(route_spec, dict):
|
||||
logger.warning("%s Ignoring non-dict ROUTES entry: %r", LOG_PREFIX, route_spec)
|
||||
continue
|
||||
|
||||
method = str(route_spec.get("method", "")).strip().upper()
|
||||
path = str(route_spec.get("path", "")).strip()
|
||||
handler_ref = route_spec.get("handler")
|
||||
if not method or not path:
|
||||
logger.warning("%s Ignoring incomplete route spec: %r", LOG_PREFIX, route_spec)
|
||||
continue
|
||||
|
||||
if isinstance(handler_ref, str):
|
||||
handler = getattr(module, handler_ref, None)
|
||||
else:
|
||||
handler = handler_ref
|
||||
if not callable(handler):
|
||||
logger.warning(
|
||||
"%s Ignoring route with missing handler %r for %s %s",
|
||||
LOG_PREFIX,
|
||||
handler_ref,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
continue
|
||||
|
||||
decorator = getattr(route_table, method.lower(), None)
|
||||
if not callable(decorator):
|
||||
logger.warning("%s Unsupported route method %s for %s", LOG_PREFIX, method, path)
|
||||
continue
|
||||
|
||||
decorator(path)(handler)
|
||||
self._route_handlers[f"{method} {path}"] = handler
|
||||
logger.info("%s buffered legacy route %s %s", LOG_PREFIX, method, path)
|
||||
|
||||
async def list_nodes(self) -> Dict[str, str]:
|
||||
return {name: self.display_names.get(name, name) for name in self.node_classes}
|
||||
|
||||
@ -289,10 +339,6 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
async def get_node_details(self, node_name: str) -> Dict[str, Any]:
|
||||
node_cls = self._get_node_class(node_name)
|
||||
is_v3 = issubclass(node_cls, _ComfyNodeInternal)
|
||||
logger.warning(
|
||||
"%s DIAG:get_node_details START | node=%s | is_v3=%s | cls=%s",
|
||||
LOG_PREFIX, node_name, is_v3, node_cls,
|
||||
)
|
||||
|
||||
input_types_raw = (
|
||||
node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {}
|
||||
@ -316,16 +362,7 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
|
||||
if is_v3:
|
||||
try:
|
||||
logger.warning(
|
||||
"%s DIAG:get_node_details calling GET_SCHEMA for %s",
|
||||
LOG_PREFIX, node_name,
|
||||
)
|
||||
schema = node_cls.GET_SCHEMA()
|
||||
logger.warning(
|
||||
"%s DIAG:get_node_details GET_SCHEMA returned for %s | schema_inputs=%s",
|
||||
LOG_PREFIX, node_name,
|
||||
[getattr(i, 'id', '?') for i in (schema.inputs or [])],
|
||||
)
|
||||
schema_v1 = asdict(schema.get_v1_info(node_cls))
|
||||
try:
|
||||
schema_v3 = asdict(schema.get_v3_info(node_cls))
|
||||
@ -532,6 +569,11 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
wrapped = self._wrap_unpicklable_objects(result)
|
||||
return wrapped
|
||||
|
||||
async def flush_pending_routes(self) -> int:
|
||||
"""Flush buffered route registrations to host via RPC. Called by host after node discovery."""
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
return await PromptServerStub.flush_child_routes()
|
||||
|
||||
async def flush_transport_state(self) -> int:
|
||||
if os.environ.get("PYISOLATE_CHILD") != "1":
|
||||
return 0
|
||||
@ -750,19 +792,13 @@ class ComfyNodeExtension(ExtensionBase):
|
||||
return self.node_instances[node_name]
|
||||
|
||||
async def before_module_loaded(self) -> None:
|
||||
# Inject initialization here if we think this is the child
|
||||
logger.warning(
|
||||
"%s DIAG:before_module_loaded START | is_child=%s",
|
||||
LOG_PREFIX, os.environ.get("PYISOLATE_CHILD"),
|
||||
)
|
||||
try:
|
||||
from comfy.isolation import initialize_proxies
|
||||
|
||||
initialize_proxies()
|
||||
logger.warning("%s DIAG:before_module_loaded initialize_proxies OK", LOG_PREFIX)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s DIAG:before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
|
||||
"%s before_module_loaded initialize_proxies FAILED: %s", LOG_PREFIX, e
|
||||
)
|
||||
|
||||
await super().before_module_loaded()
|
||||
|
||||
@ -166,6 +166,8 @@ def load_host_policy(comfy_root: Path) -> HostSecurityPolicy:
|
||||
if isinstance(whitelist_raw, dict):
|
||||
policy["whitelist"].update({str(k): str(v) for k, v in whitelist_raw.items()})
|
||||
|
||||
os.environ["PYISOLATE_SANDBOX_MODE"] = policy["sandbox_mode"]
|
||||
|
||||
logger.debug(
|
||||
"Loaded Host Policy: %d whitelisted nodes, Sandbox=%s, Network=%s",
|
||||
len(policy["whitelist"]),
|
||||
|
||||
@ -885,4 +885,6 @@ class _InnerModelProxy:
|
||||
return lambda *a, **k: self._parent._call_rpc("scale_latent_inpaint", a, k)
|
||||
if name == "diffusion_model":
|
||||
return self._parent._call_rpc("get_inner_model_attr", "diffusion_model")
|
||||
if name == "state_dict":
|
||||
return lambda: self._parent.model_state_dict()
|
||||
raise AttributeError(f"'{name}' not supported on isolated InnerModel")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pyisolate import ProxiedSingleton
|
||||
@ -152,24 +152,9 @@ class FolderPathsProxy(ProxiedSingleton):
|
||||
return list(_folder_paths().get_folder_paths(folder_name))
|
||||
|
||||
def get_filename_list(self, folder_name: str) -> list[str]:
|
||||
caller_stack = "".join(traceback.format_stack()[-4:-1])
|
||||
_fp_logger.warning(
|
||||
"][ DIAG:FolderPathsProxy.get_filename_list called | folder=%s | is_child=%s | rpc_configured=%s\n%s",
|
||||
folder_name, _is_child_process(), self._rpc is not None, caller_stack,
|
||||
)
|
||||
if _is_child_process():
|
||||
result = list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
|
||||
_fp_logger.warning(
|
||||
"][ DIAG:FolderPathsProxy.get_filename_list RPC result | folder=%s | count=%d | first=%s",
|
||||
folder_name, len(result), result[:3] if result else "EMPTY",
|
||||
)
|
||||
return result
|
||||
result = list(_folder_paths().get_filename_list(folder_name))
|
||||
_fp_logger.warning(
|
||||
"][ DIAG:FolderPathsProxy.get_filename_list LOCAL result | folder=%s | count=%d | first=%s",
|
||||
folder_name, len(result), result[:3] if result else "EMPTY",
|
||||
)
|
||||
return result
|
||||
return list(call_singleton_rpc(self._get_caller(), "rpc_get_filename_list", folder_name))
|
||||
return list(_folder_paths().get_filename_list(folder_name))
|
||||
|
||||
def get_full_path(self, folder_name: str, filename: str) -> str | None:
|
||||
if _is_child_process():
|
||||
|
||||
@ -8,7 +8,6 @@ Replaces the legacy PromptServerProxy (Singleton) with a clean Service/Stub arch
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
@ -94,14 +93,13 @@ class PromptServerStub:
|
||||
def client_id(self) -> Optional[str]:
|
||||
return "isolated_client"
|
||||
|
||||
def supports(self, feature: str) -> bool:
|
||||
return True
|
||||
@property
|
||||
def supports(self) -> set:
|
||||
return {"custom_nodes_from_web"}
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
raise RuntimeError(
|
||||
"PromptServer.app is not accessible in isolated nodes. Use RPC routes instead."
|
||||
)
|
||||
return _AppStub(self)
|
||||
|
||||
@property
|
||||
def prompt_queue(self):
|
||||
@ -140,18 +138,27 @@ class PromptServerStub:
|
||||
call_singleton_rpc(self._rpc, "ui_send_progress_text", text, node_id, sid)
|
||||
|
||||
# --- Route Registration Logic ---
|
||||
def register_route(self, method: str, path: str, handler: Callable):
|
||||
"""Register a route handler via RPC."""
|
||||
if not self._rpc:
|
||||
logger.error("RPC not initialized in PromptServerStub")
|
||||
return
|
||||
_pending_child_routes: list = []
|
||||
|
||||
# Fire registration async
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._rpc.register_route_rpc(method, path, handler))
|
||||
except RuntimeError:
|
||||
call_singleton_rpc(self._rpc, "register_route_rpc", method, path, handler)
|
||||
def register_route(self, method: str, path: str, handler: Callable):
|
||||
"""Buffer route registration. Routes are flushed via flush_child_routes()."""
|
||||
PromptServerStub._pending_child_routes.append((method, path, handler))
|
||||
logger.info("%s Buffered isolated route %s %s", LOG_PREFIX, method, path)
|
||||
|
||||
@classmethod
|
||||
async def flush_child_routes(cls):
|
||||
"""Send all buffered route registrations to host via RPC. Call from on_module_loaded."""
|
||||
if not cls._rpc:
|
||||
return 0
|
||||
flushed = 0
|
||||
for method, path, handler in cls._pending_child_routes:
|
||||
try:
|
||||
await cls._rpc.register_route_rpc(method, path, handler)
|
||||
flushed += 1
|
||||
except Exception as e:
|
||||
logger.error("%s Child route flush failed %s %s: %s", LOG_PREFIX, method, path, e)
|
||||
cls._pending_child_routes = []
|
||||
return flushed
|
||||
|
||||
|
||||
class RouteStub:
|
||||
@ -205,7 +212,6 @@ class PromptServerService(ProxiedSingleton):
|
||||
"""Host-side RPC Service for PromptServer."""
|
||||
|
||||
def __init__(self):
|
||||
# We will bind to the real server instance lazily or via global import
|
||||
pass
|
||||
|
||||
@property
|
||||
@ -231,7 +237,7 @@ class PromptServerService(ProxiedSingleton):
|
||||
async def register_route_rpc(self, method: str, path: str, child_handler_proxy):
|
||||
"""RPC Target: Register a route that forwards to the Child."""
|
||||
from aiohttp import web
|
||||
logger.debug(f"{LOG_PREFIX} Registering Isolated Route {method} {path}")
|
||||
logger.info("%s Registering isolated route %s %s", LOG_PREFIX, method, path)
|
||||
|
||||
async def route_wrapper(request: web.Request) -> web.Response:
|
||||
# 1. Capture request data
|
||||
@ -253,8 +259,8 @@ class PromptServerService(ProxiedSingleton):
|
||||
logger.error(f"{LOG_PREFIX} Isolated Route Error: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
# Register loop
|
||||
self.server.app.router.add_route(method, path, route_wrapper)
|
||||
logger.info("%s Registered isolated route %s %s", LOG_PREFIX, method, path)
|
||||
|
||||
def _serialize_response(self, result: Any) -> Any:
|
||||
"""Helper to convert Child result -> web.Response"""
|
||||
@ -269,3 +275,32 @@ class PromptServerService(ProxiedSingleton):
|
||||
return web.Response(text=result)
|
||||
# Fallback
|
||||
return web.Response(text=str(result))
|
||||
|
||||
|
||||
class _RouterStub:
|
||||
"""Captures router.add_route and router.add_static calls in isolation child."""
|
||||
|
||||
def __init__(self, stub):
|
||||
self._stub = stub
|
||||
|
||||
def add_route(self, method, path, handler, **kwargs):
|
||||
self._stub.register_route(method, path, handler)
|
||||
|
||||
def add_static(self, prefix, path, **kwargs):
|
||||
# Static file serving not supported in isolation — silently skip
|
||||
pass
|
||||
|
||||
|
||||
class _AppStub:
|
||||
"""Captures PromptServer.app access patterns in isolation child."""
|
||||
|
||||
def __init__(self, stub):
|
||||
self.router = _RouterStub(stub)
|
||||
self.frozen = False
|
||||
|
||||
def add_routes(self, routes):
|
||||
# aiohttp route table — iterate and register each
|
||||
for route in routes:
|
||||
if hasattr(route, "method") and hasattr(route, "handler"):
|
||||
self.router.add_route(route.method, route.path, route.handler)
|
||||
# StaticDef and other non-method routes — silently skip
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Drop-in replacement for comfy_api.latest._util type imports in sealed workers
|
||||
that do not have torch installed. Contains only data type definitions (TrimeshData,
|
||||
PLY, NPZ, etc.) with numpy-only dependencies.
|
||||
etc.) with numpy-only dependencies.
|
||||
|
||||
Usage in serializers:
|
||||
if _IMPORT_TORCH:
|
||||
@ -12,7 +12,5 @@ Usage in serializers:
|
||||
"""
|
||||
|
||||
from .trimesh_types import TrimeshData
|
||||
from .ply_types import PLY
|
||||
from .npz_types import NPZ
|
||||
|
||||
__all__ = ["TrimeshData", "PLY", "NPZ"]
|
||||
__all__ = ["TrimeshData"]
|
||||
|
||||
@ -1,27 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class NPZ:
|
||||
"""Ordered collection of NPZ file payloads.
|
||||
|
||||
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
|
||||
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
|
||||
``save_to`` writes numbered files into a directory.
|
||||
"""
|
||||
|
||||
def __init__(self, frames: list[bytes]) -> None:
|
||||
self.frames = frames
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return len(self.frames)
|
||||
|
||||
def save_to(self, directory: str, prefix: str = "frame") -> str:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for i, frame_bytes in enumerate(self.frames):
|
||||
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
|
||||
with open(path, "wb") as f:
|
||||
f.write(frame_bytes)
|
||||
return directory
|
||||
@ -1,97 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PLY:
|
||||
"""Point cloud payload for PLY file output.
|
||||
|
||||
Supports two schemas:
|
||||
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
|
||||
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
|
||||
|
||||
When ``raw_data`` is provided, the object acts as an opaque binary PLY
|
||||
carrier and ``save_to`` writes the bytes directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
points: np.ndarray | None = None,
|
||||
colors: np.ndarray | None = None,
|
||||
confidence: np.ndarray | None = None,
|
||||
view_id: np.ndarray | None = None,
|
||||
raw_data: bytes | None = None,
|
||||
) -> None:
|
||||
self.raw_data = raw_data
|
||||
if raw_data is not None:
|
||||
self.points = None
|
||||
self.colors = None
|
||||
self.confidence = None
|
||||
self.view_id = None
|
||||
return
|
||||
if points is None:
|
||||
raise ValueError("Either points or raw_data must be provided")
|
||||
if points.ndim != 2 or points.shape[1] != 3:
|
||||
raise ValueError(f"points must be (N, 3), got {points.shape}")
|
||||
self.points = np.ascontiguousarray(points, dtype=np.float32)
|
||||
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
|
||||
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
|
||||
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
|
||||
|
||||
@property
|
||||
def is_gaussian(self) -> bool:
|
||||
return self.raw_data is not None
|
||||
|
||||
@property
|
||||
def num_points(self) -> int:
|
||||
if self.points is not None:
|
||||
return self.points.shape[0]
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _to_numpy(arr, dtype):
|
||||
if arr is None:
|
||||
return None
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
|
||||
return np.ascontiguousarray(arr, dtype=dtype)
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
if self.raw_data is not None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(self.raw_data)
|
||||
return path
|
||||
self.points = self._to_numpy(self.points, np.float32)
|
||||
self.colors = self._to_numpy(self.colors, np.float32)
|
||||
self.confidence = self._to_numpy(self.confidence, np.float32)
|
||||
self.view_id = self._to_numpy(self.view_id, np.int32)
|
||||
N = self.num_points
|
||||
header_lines = [
|
||||
"ply",
|
||||
"format ascii 1.0",
|
||||
f"element vertex {N}",
|
||||
"property float x",
|
||||
"property float y",
|
||||
"property float z",
|
||||
]
|
||||
if self.colors is not None:
|
||||
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
|
||||
if self.confidence is not None:
|
||||
header_lines.append("property float confidence")
|
||||
if self.view_id is not None:
|
||||
header_lines.append("property int view_id")
|
||||
header_lines.append("end_header")
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(header_lines) + "\n")
|
||||
for i in range(N):
|
||||
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
|
||||
if self.colors is not None:
|
||||
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
|
||||
parts.append(f"{r} {g} {b}")
|
||||
if self.confidence is not None:
|
||||
parts.append(f"{self.confidence[i]}")
|
||||
if self.view_id is not None:
|
||||
parts.append(f"{int(self.view_id[i])}")
|
||||
f.write(" ".join(parts) + "\n")
|
||||
return path
|
||||
@ -1,40 +0,0 @@
|
||||
import os
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import io
|
||||
from comfy_api_sealed_worker.npz_types import NPZ
|
||||
|
||||
|
||||
class SaveNPZ(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveNPZ",
|
||||
display_name="Save NPZ",
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.Npz.Input("npz"),
|
||||
io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory()
|
||||
)
|
||||
batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}")
|
||||
os.makedirs(batch_dir, exist_ok=True)
|
||||
filenames = []
|
||||
for i, frame_bytes in enumerate(npz.frames):
|
||||
f = f"frame_{i:06d}.npz"
|
||||
with open(os.path.join(batch_dir, f), "wb") as fh:
|
||||
fh.write(frame_bytes)
|
||||
filenames.append(f)
|
||||
return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SaveNPZ": SaveNPZ,
|
||||
}
|
||||
@ -1,34 +0,0 @@
|
||||
import os
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import io
|
||||
from comfy_api_sealed_worker.ply_types import PLY
|
||||
|
||||
|
||||
class SavePLY(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SavePLY",
|
||||
display_name="Save PLY",
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.Ply.Input("ply"),
|
||||
io.String.Input("filename_prefix", default="pointcloud/ComfyUI"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
filename_prefix, folder_paths.get_output_directory()
|
||||
)
|
||||
f = f"{filename}_{counter:05}_.ply"
|
||||
ply.save_to(os.path.join(full_output_folder, f))
|
||||
return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SavePLY": SavePLY,
|
||||
}
|
||||
2
nodes.py
2
nodes.py
@ -2459,8 +2459,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_wan.py",
|
||||
"nodes_lotus.py",
|
||||
"nodes_hunyuan3d.py",
|
||||
"nodes_save_ply.py",
|
||||
"nodes_save_npz.py",
|
||||
"nodes_primitive.py",
|
||||
"nodes_cfg.py",
|
||||
"nodes_optimalsteps.py",
|
||||
|
||||
@ -35,5 +35,4 @@ pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
PyOpenGL
|
||||
glfw
|
||||
|
||||
pyisolate==0.10.1
|
||||
pyisolate==0.10.2
|
||||
|
||||
@ -724,6 +724,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
|
||||
from comfy.isolation.proxies.web_directory_proxy import WebDirectoryCache
|
||||
|
||||
PromptServerStub.set_rpc(fake_rpc)
|
||||
PromptServerStub._pending_child_routes = []
|
||||
stub = PromptServerStub()
|
||||
cache = WebDirectoryCache()
|
||||
cache.register_proxy("demo_ext", FakeWebDirectoryProxy(fake_rpc.transcripts))
|
||||
@ -735,6 +736,7 @@ def capture_prompt_web_exact_relay() -> dict[str, object]:
|
||||
|
||||
stub.send_progress_text("hello", "node-17")
|
||||
stub.routes.get("/demo")(demo_handler)
|
||||
asyncio.run(PromptServerStub.flush_child_routes())
|
||||
web_file = cache.get_file("demo_ext", "js/app.js")
|
||||
imported = set(sys.modules) - before
|
||||
return {
|
||||
|
||||
@ -108,6 +108,119 @@ flash_attn = "flash-attn-special"
|
||||
}
|
||||
|
||||
|
||||
def test_load_isolated_node_passes_share_torch_no_deps(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["timm", "pyyaml"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = true
|
||||
share_torch_no_deps = ["timm"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "disabled",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
specs = asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(specs) == 1
|
||||
assert captured["share_torch_no_deps"] == ["timm"]
|
||||
|
||||
|
||||
def test_on_module_loaded_registers_legacy_routes(monkeypatch):
|
||||
captured: list[tuple[str, str, Any]] = []
|
||||
|
||||
def demo_handler(body):
|
||||
return body
|
||||
|
||||
module = SimpleNamespace(
|
||||
__file__="/tmp/demo_node/__init__.py",
|
||||
__name__="demo_node",
|
||||
NODE_CLASS_MAPPINGS={},
|
||||
NODE_DISPLAY_NAME_MAPPINGS={},
|
||||
ROUTES=[
|
||||
{"method": "POST", "path": "/sam3/interactive_segment_one", "handler": "demo_handler"},
|
||||
],
|
||||
demo_handler=demo_handler,
|
||||
)
|
||||
|
||||
def fake_register_route(self, method, path, handler):
|
||||
captured.append((method, path, handler))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"comfy.isolation.proxies.prompt_server_impl.PromptServerStub.register_route",
|
||||
fake_register_route,
|
||||
)
|
||||
|
||||
extension = ComfyNodeExtension()
|
||||
asyncio.run(extension.on_module_loaded(module))
|
||||
|
||||
assert captured == [("POST", "/sam3/interactive_segment_one", demo_handler)]
|
||||
|
||||
|
||||
def test_prompt_server_stub_buffers_routes_without_rpc():
|
||||
from comfy.isolation.proxies.prompt_server_impl import PromptServerStub
|
||||
|
||||
def demo_handler(body):
|
||||
return body
|
||||
|
||||
old_rpc = PromptServerStub._rpc
|
||||
old_pending = list(PromptServerStub._pending_child_routes)
|
||||
try:
|
||||
PromptServerStub._rpc = None
|
||||
PromptServerStub._pending_child_routes = []
|
||||
PromptServerStub().register_route("POST", "/sam3/interactive_segment_one", demo_handler)
|
||||
assert PromptServerStub._pending_child_routes == [
|
||||
("POST", "/sam3/interactive_segment_one", demo_handler)
|
||||
]
|
||||
finally:
|
||||
PromptServerStub._rpc = old_rpc
|
||||
PromptServerStub._pending_child_routes = old_pending
|
||||
|
||||
|
||||
def test_load_isolated_node_rejects_undeclared_cuda_wheel_dependency(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
@ -362,6 +475,70 @@ can_isolate = true
|
||||
assert "cuda_wheels" not in captured
|
||||
|
||||
|
||||
def test_load_isolated_node_passes_extra_index_urls(tmp_path, monkeypatch):
|
||||
node_dir = tmp_path / "node"
|
||||
node_dir.mkdir()
|
||||
manifest_path = node_dir / "pyproject.toml"
|
||||
_write_manifest(
|
||||
node_dir,
|
||||
"""
|
||||
[project]
|
||||
name = "demo-node"
|
||||
dependencies = ["fbxsdkpy==2020.1.post2", "numpy>=1.0"]
|
||||
|
||||
[tool.comfy.isolation]
|
||||
can_isolate = true
|
||||
share_torch = true
|
||||
extra_index_urls = ["https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"]
|
||||
""".strip(),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyManager:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
def load_extension(self, config):
|
||||
captured.update(config)
|
||||
return _DummyExtension()
|
||||
|
||||
monkeypatch.setattr(extension_loader_module.pyisolate, "ExtensionManager", DummyManager)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_host_policy",
|
||||
lambda base_path: {
|
||||
"sandbox_mode": "disabled",
|
||||
"allow_network": False,
|
||||
"writable_paths": [],
|
||||
"readonly_paths": [],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(extension_loader_module, "is_cache_valid", lambda *args, **kwargs: True)
|
||||
monkeypatch.setattr(
|
||||
extension_loader_module,
|
||||
"load_from_cache",
|
||||
lambda *args, **kwargs: {"Node": {"display_name": "Node", "schema_v1": {}}},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "folder_paths", SimpleNamespace(base_path=str(tmp_path)))
|
||||
|
||||
specs = asyncio.run(
|
||||
load_isolated_node(
|
||||
node_dir,
|
||||
manifest_path,
|
||||
logging.getLogger("test"),
|
||||
lambda *args, **kwargs: object,
|
||||
tmp_path / "venvs",
|
||||
[],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(specs) == 1
|
||||
assert captured["extra_index_urls"] == [
|
||||
"https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple"
|
||||
]
|
||||
|
||||
|
||||
def test_maybe_wrap_model_for_isolation_uses_runtime_flag(monkeypatch):
|
||||
class DummyRegistry:
|
||||
def register(self, model):
|
||||
|
||||
45
tests/isolation/test_inner_model_state_dict.py
Normal file
45
tests/isolation/test_inner_model_state_dict.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Test that _InnerModelProxy exposes state_dict for LoRA loading."""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
pyisolate_root = repo_root.parent / "pyisolate"
|
||||
if pyisolate_root.exists():
|
||||
sys.path.insert(0, str(pyisolate_root))
|
||||
|
||||
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
|
||||
|
||||
|
||||
def test_inner_model_proxy_state_dict_returns_keys():
|
||||
"""_InnerModelProxy.state_dict() delegates to parent.model_state_dict()."""
|
||||
proxy = object.__new__(ModelPatcherProxy)
|
||||
proxy._model_id = "test_model"
|
||||
proxy._rpc = MagicMock()
|
||||
proxy._model_type_name = "SDXL"
|
||||
proxy._inner_model_channels = None
|
||||
|
||||
fake_keys = ["diffusion_model.input.weight", "diffusion_model.output.weight"]
|
||||
proxy._call_rpc = MagicMock(return_value=fake_keys)
|
||||
|
||||
inner = proxy.model
|
||||
sd = inner.state_dict()
|
||||
|
||||
assert isinstance(sd, dict)
|
||||
assert "diffusion_model.input.weight" in sd
|
||||
assert "diffusion_model.output.weight" in sd
|
||||
proxy._call_rpc.assert_called_with("model_state_dict", None)
|
||||
|
||||
|
||||
def test_inner_model_proxy_state_dict_callable():
|
||||
"""state_dict is a callable, not a property — matches torch.nn.Module interface."""
|
||||
proxy = object.__new__(ModelPatcherProxy)
|
||||
proxy._model_id = "test_model"
|
||||
proxy._rpc = MagicMock()
|
||||
proxy._model_type_name = "SDXL"
|
||||
proxy._inner_model_channels = None
|
||||
|
||||
proxy._call_rpc = MagicMock(return_value=[])
|
||||
|
||||
inner = proxy.model
|
||||
assert callable(inner.state_dict)
|
||||
Loading…
Reference in New Issue
Block a user