mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Compare commits
5 Commits
fd9c93e77e
...
6b57b80fec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b57b80fec | ||
|
|
cd8c7a2306 | ||
|
|
6bcd8b96ab | ||
|
|
97e84d399c | ||
|
|
2ff1d3d042 |
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
"""Register a node replacement mapping.
|
||||
|
||||
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||
is already registered, the duplicate is ignored. This prevents stale
|
||||
entries from accumulating when custom nodes are reloaded in the same
|
||||
process (e.g. via ComfyUI-Manager).
|
||||
"""
|
||||
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||
for entry in existing:
|
||||
if entry.new_node_id == node_replace.new_node_id:
|
||||
logging.debug(
|
||||
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||
node_replace.old_node_id, node_replace.new_node_id,
|
||||
)
|
||||
return
|
||||
existing.append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
|
||||
@ -26,6 +26,7 @@ import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import comfy.float
|
||||
import comfy.hooks
|
||||
@ -1651,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory)
|
||||
in_loop = bool(getattr(tqdm.tqdm, "_instances", None))
|
||||
level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO
|
||||
self._last_prepare_log_key = log_key
|
||||
logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
@ -560,7 +560,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
@ -580,7 +580,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
@ -597,7 +597,7 @@ class PromptServer():
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
@ -614,7 +614,7 @@ class PromptServer():
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{filename}\"",
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
}
|
||||
)
|
||||
|
||||
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Tests for NodeReplaceManager registration behavior."""
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def NodeReplaceManager(monkeypatch):
|
||||
"""Provide NodeReplaceManager with `nodes` stubbed.
|
||||
|
||||
`app.node_replace_manager` does `import nodes` at module level, which pulls in
|
||||
torch + the full ComfyUI graph. register() doesn't actually need it, so we
|
||||
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
|
||||
module so it picks up the stub instead of any cached real import.
|
||||
"""
|
||||
fake_nodes = types.ModuleType("nodes")
|
||||
fake_nodes.NODE_CLASS_MAPPINGS = {}
|
||||
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
|
||||
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
|
||||
module = importlib.import_module("app.node_replace_manager")
|
||||
yield module.NodeReplaceManager
|
||||
# Drop the freshly-imported module so the next test (or a later real import
|
||||
# of `nodes`) starts from a clean slate.
|
||||
sys.modules.pop("app.node_replace_manager", None)
|
||||
|
||||
|
||||
class FakeNodeReplace:
|
||||
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
|
||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||
input_mapping=None, output_mapping=None):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
|
||||
def test_register_adds_replacement(NodeReplaceManager):
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
assert manager.has_replacement("OldNode")
|
||||
assert len(manager.get_replacement("OldNode")) == 1
|
||||
|
||||
|
||||
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
|
||||
"""Different new_node_ids for the same old_node_id should all be kept."""
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
|
||||
replacements = manager.get_replacement("OldNode")
|
||||
assert len(replacements) == 2
|
||||
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
|
||||
|
||||
|
||||
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
|
||||
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||
assert len(manager.get_replacement("OldNode")) == 1
|
||||
|
||||
|
||||
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
|
||||
"""First registration wins; later duplicates with different mappings are ignored."""
|
||||
manager = NodeReplaceManager()
|
||||
first = FakeNodeReplace(
|
||||
new_node_id="NewNode", old_node_id="OldNode",
|
||||
input_mapping=[{"new_id": "a", "old_id": "x"}],
|
||||
)
|
||||
second = FakeNodeReplace(
|
||||
new_node_id="NewNode", old_node_id="OldNode",
|
||||
input_mapping=[{"new_id": "b", "old_id": "y"}],
|
||||
)
|
||||
manager.register(first)
|
||||
manager.register(second)
|
||||
replacements = manager.get_replacement("OldNode")
|
||||
assert len(replacements) == 1
|
||||
assert replacements[0] is first
|
||||
|
||||
|
||||
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
|
||||
manager = NodeReplaceManager()
|
||||
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
|
||||
assert len(manager.get_replacement("OldA")) == 1
|
||||
assert len(manager.get_replacement("OldB")) == 1
|
||||
Loading…
Reference in New Issue
Block a user