mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-25 08:27:25 +08:00
Compare commits
15 Commits
bc1a1f8833
...
db3bd1eb0f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db3bd1eb0f | ||
|
|
5538f62b0b | ||
|
|
2806163f6e | ||
|
|
cea8d0925f | ||
|
|
b138133ffa | ||
|
|
5225f109a6 | ||
|
|
e35fe5bc09 | ||
|
|
77054cd49e | ||
|
|
1cd2730b25 | ||
|
|
d4351f77f8 | ||
|
|
9837dd368a | ||
|
|
62ec9a3238 | ||
|
|
b20cb7892e | ||
|
|
b9b24d425b | ||
|
|
d731cb6ae1 |
@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -7,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
@ -43,6 +46,7 @@ class NodeReplaceManager:
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
import nodes
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
@ -94,6 +98,60 @@ class NodeReplaceManager:
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None):
|
||||
"""Load node_replacements.json from a custom node directory and register replacements.
|
||||
|
||||
Custom node authors can ship a node_replacements.json file in their repo root
|
||||
to define node replacements declaratively. The file format matches the output
|
||||
of NodeReplace.as_dict(), keyed by old_node_id.
|
||||
|
||||
Fail-open: all errors are logged and skipped so a malformed file never
|
||||
prevents the custom node from loading.
|
||||
"""
|
||||
replacements_path = os.path.join(module_dir, "node_replacements.json")
|
||||
if not os.path.isfile(replacements_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(replacements_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
|
||||
return
|
||||
|
||||
if _node_replace_class is None:
|
||||
from comfy_api.latest._io import NodeReplace
|
||||
_node_replace_class = NodeReplace
|
||||
|
||||
count = 0
|
||||
for old_node_id, replacements in data.items():
|
||||
if not isinstance(replacements, list):
|
||||
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
|
||||
continue
|
||||
for entry in replacements:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
new_node_id = entry.get("new_node_id", "")
|
||||
if not new_node_id:
|
||||
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
|
||||
continue
|
||||
self.register(_node_replace_class(
|
||||
new_node_id=new_node_id,
|
||||
old_node_id=entry.get("old_node_id", old_node_id),
|
||||
old_widget_ids=entry.get("old_widget_ids"),
|
||||
input_mapping=entry.get("input_mapping"),
|
||||
output_mapping=entry.get("output_mapping"),
|
||||
))
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
|
||||
except json.JSONDecodeError as e:
|
||||
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
|
||||
@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -21,7 +23,15 @@ try:
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
display_name="Color Transfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
|
||||
@ -49,7 +49,7 @@ class Int(io.ComfyNode):
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
72
nodes.py
72
nodes.py
@ -1754,57 +1754,49 @@ class LoadImage:
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
|
||||
class LoadImageMask(LoadImage):
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
types = super().INPUT_TYPES()
|
||||
return {
|
||||
"required": {
|
||||
**types["required"],
|
||||
"channel": (s._color_channels, )
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = node_helpers.pillow(Image.open, image_path)
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
FUNCTION = "load_image_mask"
|
||||
|
||||
def load_image_mask(self, image, channel):
|
||||
image_tensor, mask_tensor = super().load_image(image)
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
|
||||
if c == 'A':
|
||||
return (mask_tensor,)
|
||||
|
||||
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
|
||||
|
||||
if channel_idx < image_tensor.shape[-1]:
|
||||
return (image_tensor[..., channel_idx].clone(),)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask.unsqueeze(0),)
|
||||
empty_mask = torch.zeros(
|
||||
image_tensor.shape[:-1],
|
||||
dtype=image_tensor.dtype,
|
||||
device=image_tensor.device
|
||||
)
|
||||
return (empty_mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
@ -2204,6 +2196,12 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
|
||||
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
|
||||
|
||||
# Only load node_replacements.json from directory-based custom nodes (proper packs).
|
||||
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
|
||||
if os.path.isdir(module_path):
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name)
|
||||
|
||||
try:
|
||||
from comfy_config import config_parser
|
||||
|
||||
|
||||
217
tests/test_node_replacements_json.py
Normal file
217
tests/test_node_replacements_json.py
Normal file
@ -0,0 +1,217 @@
|
||||
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
|
||||
node_replacements.json from custom node directories."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
|
||||
|
||||
class SimpleNodeReplace:
|
||||
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import)."""
|
||||
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||
input_mapping=None, output_mapping=None):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
class TestLoadFromJson(unittest.TestCase):
|
||||
"""Test auto-registration of node_replacements.json from custom node directories."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.manager = NodeReplaceManager()
|
||||
|
||||
def _write_json(self, data):
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _load(self):
|
||||
self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)
|
||||
|
||||
def test_no_file_does_nothing(self):
|
||||
"""No node_replacements.json — should silently do nothing."""
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_empty_object(self):
|
||||
"""Empty {} — should do nothing."""
|
||||
self._write_json({})
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_single_replacement(self):
|
||||
"""Single replacement entry registers correctly."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}],
|
||||
"output_mapping": [{"new_idx": 0, "old_idx": 0}],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertIn("OldNode", result)
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
entry = result["OldNode"][0]
|
||||
self.assertEqual(entry["new_node_id"], "NewNode")
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
|
||||
self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])
|
||||
|
||||
def test_multiple_replacements(self):
|
||||
"""Multiple old_node_ids each with entries."""
|
||||
self._write_json({
|
||||
"NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}],
|
||||
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertIn("NodeA", result)
|
||||
self.assertIn("NodeC", result)
|
||||
|
||||
def test_multiple_alternatives_for_same_node(self):
|
||||
"""Multiple replacement options for the same old node."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"new_node_id": "AltA", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "AltB", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 2)
|
||||
|
||||
def test_null_mappings(self):
|
||||
"""Null input/output mappings (trivial replacement)."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": None,
|
||||
"output_mapping": None,
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertIsNone(entry["input_mapping"])
|
||||
self.assertIsNone(entry["output_mapping"])
|
||||
|
||||
def test_old_node_id_defaults_to_key(self):
|
||||
"""If old_node_id is missing from entry, uses the dict key."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode"}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_node_id"], "OldNode")
|
||||
|
||||
def test_invalid_json_skips(self):
|
||||
"""Invalid JSON file — should warn and skip, not crash."""
|
||||
path = os.path.join(self.tmpdir, "node_replacements.json")
|
||||
with open(path, "w") as f:
|
||||
f.write("{invalid json")
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_object_json_skips(self):
|
||||
"""JSON array instead of object — should warn and skip."""
|
||||
self._write_json([1, 2, 3])
|
||||
self._load()
|
||||
self.assertEqual(self.manager.as_dict(), {})
|
||||
|
||||
def test_non_list_value_skips(self):
|
||||
"""Value is not a list — should warn and skip that key."""
|
||||
self._write_json({
|
||||
"OldNode": "not a list",
|
||||
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertNotIn("OldNode", result)
|
||||
self.assertIn("GoodNode", result)
|
||||
|
||||
def test_with_old_widget_ids(self):
|
||||
"""old_widget_ids are passed through."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"old_widget_ids": ["width", "height"],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(entry["old_widget_ids"], ["width", "height"])
|
||||
|
||||
def test_set_value_in_input_mapping(self):
|
||||
"""input_mapping with set_value entries."""
|
||||
self._write_json({
|
||||
"OldNode": [{
|
||||
"new_node_id": "NewNode",
|
||||
"old_node_id": "OldNode",
|
||||
"input_mapping": [
|
||||
{"new_id": "method", "set_value": "lanczos"},
|
||||
{"new_id": "size", "old_id": "dimension"},
|
||||
],
|
||||
}]
|
||||
})
|
||||
self._load()
|
||||
entry = self.manager.as_dict()["OldNode"][0]
|
||||
self.assertEqual(len(entry["input_mapping"]), 2)
|
||||
|
||||
def test_missing_new_node_id_skipped(self):
|
||||
"""Entry without new_node_id is skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
{"old_node_id": "OldNode"},
|
||||
{"new_node_id": "", "old_node_id": "OldNode"},
|
||||
{"new_node_id": "ValidNew", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")
|
||||
|
||||
def test_non_dict_entry_skipped(self):
|
||||
"""Non-dict entries in the list are silently skipped."""
|
||||
self._write_json({
|
||||
"OldNode": [
|
||||
"not a dict",
|
||||
{"new_node_id": "NewNode", "old_node_id": "OldNode"},
|
||||
]
|
||||
})
|
||||
self._load()
|
||||
result = self.manager.as_dict()
|
||||
self.assertEqual(len(result["OldNode"]), 1)
|
||||
|
||||
def test_has_replacement_after_load(self):
|
||||
"""Manager reports has_replacement correctly after JSON load."""
|
||||
self._write_json({
|
||||
"OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
|
||||
})
|
||||
self.assertFalse(self.manager.has_replacement("OldNode"))
|
||||
self._load()
|
||||
self.assertTrue(self.manager.has_replacement("OldNode"))
|
||||
self.assertFalse(self.manager.has_replacement("UnknownNode"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user