From 0d9232f02c19be38a2ae02f48fa114cfba936fdc Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Mon, 10 Nov 2025 09:47:27 -0800 Subject: [PATCH] wip python eval nodes --- comfy/cmd/main_pre.py | 1 + comfy_extras/eval_web/eval_python.js | 178 ++++++ comfy_extras/eval_web/ky_eval_python.js | 377 ------------- comfy_extras/nodes/nodes_eval.py | 187 ++++--- tests/unit/test_eval_nodes.py | 693 ++++++++++++++++++++++++ 5 files changed, 971 insertions(+), 465 deletions(-) create mode 100644 comfy_extras/eval_web/eval_python.js delete mode 100644 comfy_extras/eval_web/ky_eval_python.js create mode 100644 tests/unit/test_eval_nodes.py diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 249482188..7c06e4cee 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -42,6 +42,7 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_ warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*') +warnings.filterwarnings("ignore", category=UserWarning, message="Please use the new API settings to control TF32 behavior.*") warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", category=FutureWarning) warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning) warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning) diff --git a/comfy_extras/eval_web/eval_python.js b/comfy_extras/eval_web/eval_python.js new file mode 100644 index 000000000..d344d9e50 --- /dev/null +++ b/comfy_extras/eval_web/eval_python.js @@ -0,0 +1,178 @@ +/** + * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode + * + * MIT License + * + * Copyright (c) 2024 Kevin Yuan + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +import { app } from "../../scripts/app.js"; +import { makeElement, findWidget } from "./ace_utils.js"; + +// Load Ace editor using script tag for Safari compatibility +// The noconflict build includes AMD loader that works in all browsers +let ace; +const aceLoadPromise = new Promise((resolve) => { + if (window.ace) { + ace = window.ace; + resolve(); + } else { + const script = document.createElement('script'); + script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js"; + script.onload = () => { + ace = window.ace; + ace.config.set("basePath", "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict"); + resolve(); + }; + document.head.appendChild(script); + } +}); + +await aceLoadPromise; + + +function getPosition(node, ctx, w_width, y, n_height) { + const margin = 5; + + const rect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + return { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e + rect.left}px`, + top: `${transform.d + transform.f + rect.top}px`, + maxWidth: `${w_width - margin * 2}px`, + maxHeight: `${n_height - margin * 2 - y - 15}px`, + width: `${w_width - margin * 2}px`, + height: "90%", + position: "absolute", + scrollbarColor: "var(--descrip-text) var(--bg-color)", + scrollbarWidth: "thin", + zIndex: app.graph._nodes.indexOf(node), + }; +} + +// Create code editor widget +function codeEditor(node, inputName, inputData) { + const widget = { + type: "pycode", + name: inputName, + options: { hideOnZoom: true }, + value: inputData[1]?.default || "", + draw(ctx, node, widgetWidth, y) { + const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden"; + + this.codeElement.hidden = hidden; + + if (hidden) { + this.options.onHide?.(this); + return; + } + + Object.assign(this.codeElement.style, getPosition(node, ctx, widgetWidth, y, node.size[1])); + }, + computeSize() { + return [500, 250]; + }, + }; + + widget.codeElement = makeElement("pre", { + innerHTML: widget.value, + }); + + widget.editor = ace.edit(widget.codeElement); + widget.editor.setTheme("ace/theme/monokai"); + widget.editor.session.setMode("ace/mode/python"); + widget.editor.setOptions({ + enableAutoIndent: true, + enableLiveAutocompletion: true, + enableBasicAutocompletion: true, + fontFamily: "monospace", + }); + widget.codeElement.hidden = true; + + document.body.appendChild(widget.codeElement); + + const originalCollapse = node.collapse; + node.collapse = function () { + originalCollapse.apply(this, arguments); + widget.codeElement.hidden = !!this.flags?.collapsed; + }; + + return widget; +} + +// Trigger workflow change tracking +function markWorkflowChanged() { + app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); +} + +// Register extensions +app.registerExtension({ + name: "Comfy.EvalPython", + getCustomWidgets(app) { + return { + PYCODE: (node, inputName, inputData) => { + const widget = codeEditor(node, inputName, inputData); + + widget.editor.getSession().on("change", () => { + widget.value = widget.editor.getValue(); + markWorkflowChanged(); + }); + + node.onRemoved = function () { + for (const w of this.widgets) { + if (w?.codeElement) { + w.codeElement.remove(); + } + } + }; + + node.addCustomWidget(widget); + + return widget; + }, + }; + }, + + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name === "EvalPython") { + const originalOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + originalOnConfigure?.apply(this, arguments); + + if (info?.widgets_values?.length) { + const widgetCodeIndex = findWidget(this, "pycode", "type", "findIndex"); + const editor = this.widgets[widgetCodeIndex]?.editor; + + if (editor) { + editor.setValue(info.widgets_values[widgetCodeIndex]); + editor.clearSelection(); + } + } + }; + } + }, +}); diff --git a/comfy_extras/eval_web/ky_eval_python.js b/comfy_extras/eval_web/ky_eval_python.js deleted file mode 100644 index 3d65aa5c0..000000000 --- a/comfy_extras/eval_web/ky_eval_python.js +++ /dev/null @@ -1,377 +0,0 @@ -/** - * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode - * - * MIT License - * - * Copyright (c) 2024 Kevin Yuan - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -import { app } from "../../scripts/app.js"; - -import * as ace from "https://cdn.jsdelivr.net/npm/ace-code@1.43.4/+esm"; -import { makeElement, findWidget } from "./ace_utils.js"; - -// Constants -const varTypes = ["int", "boolean", "string", "float", "json", "list", "dict"]; -const typeMap = { - int: "int", - boolean: "bool", - string: "str", - float: "float", - json: "json", - list: "list", - dict: "dict", -}; - -ace.config.setModuleLoader('ace/mode/python', () => - import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/mode-python.js') -); - -ace.config.setModuleLoader('ace/theme/monokai', () => - import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/theme-monokai.js') -); - -function getPostition(node, ctx, w_width, y, n_height) { - const margin = 5; - - const rect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); - - return { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e + rect.left}px`, - top: `${transform.d + transform.f + rect.top}px`, - maxWidth: `${w_width - margin * 2}px`, - maxHeight: `${n_height - margin * 2 - y - 15}px`, - width: `${w_width - margin * 2}px`, - height: "90%", - position: "absolute", - scrollbarColor: "var(--descrip-text) var(--bg-color)", - scrollbarWidth: "thin", - zIndex: app.graph._nodes.indexOf(node), - }; -} - -// Create editor code -function codeEditor(node, inputName, inputData) { - const widget = { - type: "pycode", - name: inputName, - options: { hideOnZoom: true }, - value: - inputData[1]?.default || - `def my(a, b=1): - return a * b
- -r0 = str(my(23, 9))`, - draw(ctx, node, widget_width, y, widget_height) { - const hidden = node.flags?.collapsed || (!!widget.options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.type === "converted-widget" || widget.type === "hidden"; - - widget.codeElement.hidden = hidden; - - if (hidden) { - widget.options.onHide?.(widget); - return; - } - - Object.assign(this.codeElement.style, getPostition(node, ctx, widget_width, y, node.size[1])); - }, - computeSize(...args) { - return [500, 250]; - }, - }; - - widget.codeElement = makeElement("pre", { - innerHTML: widget.value, - }); - - widget.editor = ace.edit(widget.codeElement); - widget.editor.setTheme("ace/theme/monokai"); - widget.editor.session.setMode("ace/mode/python"); - widget.editor.setOptions({ - enableAutoIndent: true, - enableLiveAutocompletion: true, - enableBasicAutocompletion: true, - fontFamily: "monospace", - }); - widget.codeElement.hidden = true; - - document.body.appendChild(widget.codeElement); - - const collapse = node.collapse; - node.collapse = function () { - collapse.apply(this, arguments); - if (this.flags?.collapsed) { - widget.codeElement.hidden = true; - } else { - if (this.flags?.collapsed === false) { - widget.codeElement.hidden = false; - } - } - }; - - return widget; -} - -// Save data to workflow forced! -function saveValue() { - app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); -} - -// Register extensions -app.registerExtension({ - name: "KYNode.KY_Eval_Python", - getCustomWidgets(app) { - return { - PYCODE: (node, inputName, inputData, app) => { - const widget = codeEditor(node, inputName, inputData); - - widget.editor.getSession().on("change", function (e) { - widget.value = widget.editor.getValue(); - saveValue(); - }); - - const varTypeList = node.addWidget( - "combo", - "select_type", - "string", - (v) => { - // widget.editor.setTheme(`ace/theme/${varTypeList.value}`); - }, - { - values: varTypes, - serialize: false, - }, - ); - - // 6. 使用 addDOMWidget 将容器添加到节点上 - // - 第一个参数是 widget 的名称,在节点内部需要是唯一的。 - // - 第二个参数是 widget 的类型,对于自定义 DOM 元素,通常是 "div"。 - // - 第三个参数是您创建的 DOM 元素。 - // - 第四个参数是一个选项对象,可以用来配置 widget。 - // node.addDOMWidget("rowOfButtons", "div", container, { - // }); - node.addWidget("button", "Add Input variable", "add_input_variable", async () => { - // Input name variable and check - let nameInput = node?.inputs?.length ? `p${node.inputs.length - 1}` : "p0"; - - const currentWidth = node.size[0]; - let tp = varTypeList.value; - nameInput = nameInput + "_" + typeMap[tp]; - node.addInput(nameInput, "*"); - node.setSize([currentWidth, node.size[1]]); - let cv = widget.editor.getValue(); - if (tp === "json") { - cv = cv + "\n" + nameInput + " = json.loads(" + nameInput + ")"; - } else if (tp === "list") { - cv = cv + "\n" + nameInput + " = []"; - } else if (tp === "dict") { - cv = cv + "\n" + nameInput + " = {}"; - } else { - cv = cv + "\n" + nameInput + " = " + typeMap[tp] + "(" + nameInput + ")"; - } - widget.editor.setValue(cv); - saveValue(); - }); - - node.addWidget("button", "Add Output variable", "add_output_variable", async () => { - const currentWidth = node.size[0]; - // Output name variable - let nameOutput = node?.outputs?.length ? `r${node.outputs.length}` : "r0"; - let tp = varTypeList.value; - nameOutput = nameOutput + "_" + typeMap[tp]; - node.addOutput(nameOutput, tp); - node.setSize([currentWidth, node.size[1]]); - let cv = widget.editor.getValue(); - if (tp === "json") { - cv = cv + "\n" + nameOutput + " = json.dumps(" + nameOutput + ")"; - } else if (tp === "list") { - cv = cv + "\n" + nameOutput + " = []"; - } else if (tp === "dict") { - cv = cv + "\n" + nameOutput + " = {}"; - } else { - cv = cv + "\n" + nameOutput + " = " + typeMap[tp] + "(" + nameOutput + ")"; - } - widget.editor.setValue(cv); - saveValue(); - }); - - node.onRemoved = function () { - for (const w of node?.widgets) { - if (w?.codeElement) w.codeElement.remove(); - } - }; - - node.addCustomWidget(widget); - - return widget; - }, - }; - }, - - async beforeRegisterNodeDef(nodeType, nodeData, app) { - // --- IDENode - if (nodeData.name === "KY_Eval_Python") { - // Node Created - const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = async function () { - const ret = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; - - const node_title = await this.getTitle(); - const nodeName = `${nodeData.name}_${this.id}`; - - this.name = nodeName; - - // Create default inputs, when first create node - if (this?.inputs?.length < 2) { - ["p0_str"].forEach((inputName) => { - const currentWidth = this.size[0]; - this.addInput(inputName, "*"); - this.setSize([currentWidth, this.size[1]]); - }); - } - - const widgetEditor = findWidget(this, "pycode", "type"); - - this.setSize([530, this.size[1]]); - - return ret; - }; - - const onDrawForeground = nodeType.prototype.onDrawForeground; - nodeType.prototype.onDrawForeground = function (ctx) { - const r = onDrawForeground?.apply?.(this, arguments); - - // if (this.flags?.collapsed) return r; - - if (this?.outputs?.length) { - for (let o = 0; o < this.outputs.length; o++) { - const { name, type } = this.outputs[o]; - const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; - const nameSize = ctx.measureText(name); - const typeSize = ctx.measureText(`[${type === "*" ? "any" : type.toLowerCase()}]`); - - ctx.fillStyle = colorType === "" ? "#AAA" : colorType; - ctx.font = "12px Arial, sans-serif"; - ctx.textAlign = "right"; - ctx.fillText(`[${type === "*" ? "any" : type.toLowerCase()}]`, this.size[0] - nameSize.width - typeSize.width, o * 20 + 19); - } - } - - if (this?.inputs?.length) { - const not_showing = ["select_type", "pycode"]; - for (let i = 1; i < this.inputs.length; i++) { - const { name, type } = this.inputs[i]; - if (not_showing.includes(name)) continue; - const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; - const nameSize = ctx.measureText(name); - - ctx.fillStyle = !colorType || colorType === "" ? "#AAA" : colorType; - ctx.font = "12px Arial, sans-serif"; - ctx.textAlign = "left"; - ctx.fillText(`[${type === "*" ? "any" : type.toLowerCase()}]`, nameSize.width + 25, i * 20); - } - } - return r; - }; - - // Node Configure - const onConfigure = nodeType.prototype.onConfigure; - nodeType.prototype.onConfigure = function (node) { - onConfigure?.apply(this, arguments); - if (node?.widgets_values?.length) { - const widget_code_id = findWidget(this, "pycode", "type", "findIndex"); - const widget_theme_id = findWidget(this, "varTypeList", "name", "findIndex"); - const widget_language_id = findWidget(this, "language", "name", "findIndex"); - - const editor = this.widgets[widget_code_id]?.editor; - - if (editor) { - // editor.setTheme( - // `ace/theme/${this.widgets_values[widget_theme_id]}` - // ); - // editor.session.setMode( - // `ace/mode/${this.widgets_values[widget_language_id]}` - // ); - editor.setValue(this.widgets_values[widget_code_id]); - editor.clearSelection(); - } - } - }; - - // ExtraMenuOptions - const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; - nodeType.prototype.getExtraMenuOptions = function (_, options) { - getExtraMenuOptions?.apply(this, arguments); - - const past_index = options.length - 1; - const past = options[past_index]; - - if (!!past) { - // Inputs remove - for (const input_idx in this.inputs) { - const input = this.inputs[input_idx]; - - if (["language", "select_type"].includes(input.name)) continue; - - options.splice(past_index + 1, 0, { - content: `Remove Input ${input.name}`, - callback: (e) => { - const currentWidth = this.size[0]; - if (input.link) { - app.graph.removeLink(input.link); - } - this.removeInput(input_idx); - this.setSize([80, this.size[1]]); - saveValue(); - }, - }); - } - - // Output remove - for (const output_idx in this.outputs) { - const output = this.outputs[output_idx]; - - if (output.name === "r0") continue; - - options.splice(past_index + 1, 0, { - content: `Remove Output ${output.name}`, - callback: (e) => { - const currentWidth = this.size[0]; - if (output.link) { - app.graph.removeLink(output.link); - } - this.removeOutput(output_idx); - this.setSize([currentWidth, this.size[1]]); - saveValue(); - }, - }); - } - } - }; - // end - ExtraMenuOptions - } - }, -}); diff --git a/comfy_extras/nodes/nodes_eval.py b/comfy_extras/nodes/nodes_eval.py index a09739c21..ff04522eb 100644 --- a/comfy_extras/nodes/nodes_eval.py +++ b/comfy_extras/nodes/nodes_eval.py @@ -1,109 +1,120 @@ -import re -import traceback -import types +import logging +from comfy.comfy_types import IO from comfy.execution_context import current_execution_context from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes from comfy.nodes.package_typing import CustomNode -remove_type_name = re.compile(r"(\{.*\})", re.I | re.M) +logger = logging.getLogger(__name__) -# Hack: string type that is always equal in not equal comparisons, thanks pythongosssss -class AnyType(str): - def __ne__(self, __value: object) -> bool: - return False +def eval_python(inputs=5, outputs=5, name=None, input_is_list=None, output_is_list=None): + """ + Factory function to create EvalPython node classes with configurable input/output counts. + Args: + inputs: Number of input value slots (default: 5) + outputs: Number of output item slots (default: 5) + name: Class name (default: f"EvalPython_{inputs}_{outputs}") + input_is_list: Optional list of bools indicating which inputs accept lists (default: None, meaning all scalar) + output_is_list: Optional tuple of bools indicating which outputs return lists (default: None, meaning all scalar) -PY_CODE = AnyType("*") -IDEs_DICT = {} + Returns: + A CustomNode subclass configured with the specified inputs/outputs + """ + if name is None: + name = f"EvalPython_{inputs}_{outputs}" - -# - Thank you very much for the class -> Trung0246 - -# - https://github.com/Trung0246/ComfyUI-0246/blob/main/utils.py#L51 -class TautologyStr(str): - def __ne__(self, other): - return False - - -class ByPassTypeTuple(tuple): - def __getitem__(self, index): - if index > 0: - index = 0 - item = super().__getitem__(index) - if isinstance(item, str): - return TautologyStr(item) - return item - - -# --------------------------- - - -class KY_Eval_Python(CustomNode): - @classmethod - def INPUT_TYPES(s): - - return { - "required": { - "pycode": ( - "PYCODE", - { - "default": """import re, json, os, traceback -from time import strftime - -def runCode(): - nowDataTime = strftime("%Y-%m-%d %H:%M:%S") - return f"Hello ComfyUI with us today {nowDataTime}!" -r0_str = runCode() + unique_id + default_code = f""" +print("Hello World!") +return {", ".join([f"value{i}" for i in range(inputs)])} """ - }, - ), - }, - "hidden": {"unique_id": "UNIQUE_ID", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = ByPassTypeTuple((PY_CODE,)) - RETURN_NAMES = ("r0_str",) - FUNCTION = "exec_py" - DESCRIPTION = "IDE Node is an node that allows you to run code written in Python or Javascript directly in the node." - CATEGORY = "KYNode/Code" + class EvalPythonNode(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "pycode": ( + "PYCODE", + { + "default": default_code + }, + ), + }, + "optional": {f"value{i}": (IO.ANY, {}) for i in range(inputs)}, + } - def exec_py(self, pycode, unique_id, extra_pnginfo, **kwargs): - ctx = current_execution_context() - if ctx.configuration.enable_eval is not True: - raise ValueError("Python eval is disabled") + RETURN_TYPES = tuple(IO.ANY for _ in range(outputs)) + RETURN_NAMES = tuple(f"item{i}" for i in range(outputs)) + OUTPUT_IS_LIST = output_is_list + INPUT_IS_LIST = input_is_list is not None + FUNCTION = "exec_py" + DESCRIPTION = "" + CATEGORY = "eval" - if unique_id not in IDEs_DICT: - IDEs_DICT[unique_id] = self + def exec_py(self, pycode, **kwargs): + ctx = current_execution_context() - outputs = {unique_id: unique_id} - if extra_pnginfo and 'workflow' in extra_pnginfo and extra_pnginfo['workflow']: - for node in extra_pnginfo['workflow']['nodes']: - if node['id'] == int(unique_id): - outputs_valid = [ouput for ouput in node.get('outputs', []) if ouput.get('name', '') != '' and ouput.get('type', '') != ''] - outputs = {ouput['name']: None for ouput in outputs_valid} - self.RETURN_TYPES = ByPassTypeTuple(out["type"] for out in outputs_valid) - self.RETURN_NAMES = tuple(name for name in outputs.keys()) - my_namespace = types.SimpleNamespace() - # 从 prompt 对象中提取 prompt_id - # if extra_data and 'extra_data' in extra_data and 'prompt_id' in extra_data['extra_data']: - # prompt_id = prompt['extra_data']['prompt_id'] - # outputs['p0_str'] = p0_str + # Ensure all value inputs have a default of None + kwargs = { + **{f"value{i}": None for i in range(inputs)}, + **kwargs, + } - my_namespace.__dict__.update(outputs) - my_namespace.__dict__.update({prop: kwargs[prop] for prop in kwargs}) - # my_namespace.__dict__.setdefault("r0_str", "The r0 variable is not assigned") + def print(*args): + ctx.server.send_progress_text(" ".join(map(str, args)), ctx.node_id) - try: - exec(pycode, my_namespace.__dict__) - except Exception as e: - err = traceback.format_exc() - mc = re.search(r'line (\d+), in ([\w\W]+)$', err, re.MULTILINE) - msg = mc[1] + ':' + mc[2] - my_namespace.r0 = f"Error Line{msg}" + if not ctx.configuration.enable_eval: + raise ValueError("Python eval is disabled") - new_dict = {key: my_namespace.__dict__[key] for key in my_namespace.__dict__ if key not in ['__builtins__', *kwargs.keys()] and not callable(my_namespace.__dict__[key])} - return (*new_dict.values(),) + # Extract value arguments in order + value_args = [kwargs.pop(f"value{i}") for i in range(inputs)] + arg_names = ", ".join(f"value{i}=None" for i in range(inputs)) + + # Wrap pycode in a function to support return statements + wrapped_code = f"def _eval_func({arg_names}):\n" + for line in pycode.splitlines(): + wrapped_code += " " + line + "\n" + + globals_for_eval = { + **kwargs, + "logger": logger, + "print": print, + } + + # Execute wrapped function definition + exec(wrapped_code, globals_for_eval) + + # Call the function with value arguments + results = globals_for_eval["_eval_func"](*value_args) + + # Normalize results to match output count + if not isinstance(results, tuple): + results = (results,) + + if len(results) < outputs: + results += (None,) * (outputs - len(results)) + elif len(results) > outputs: + results = results[:outputs] + + return results + + # Set the class name for better debugging/introspection + EvalPythonNode.__name__ = name + EvalPythonNode.__qualname__ = name + + return EvalPythonNode + + +# Create the default EvalPython node with 5 inputs and 5 outputs +EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5") +EvalPython = EvalPython_5_5 # Backward compatibility alias + +# Create list variants +EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None) +EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,)) +EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,)) export_custom_nodes() diff --git a/tests/unit/test_eval_nodes.py b/tests/unit/test_eval_nodes.py new file mode 100644 index 000000000..f2cb0c763 --- /dev/null +++ b/tests/unit/test_eval_nodes.py @@ -0,0 +1,693 @@ +import pytest +from unittest.mock import Mock, patch + +from comfy.cli_args import default_configuration +from comfy.execution_context import context_configuration +from comfy_extras.nodes.nodes_eval import ( + EvalPython, + EvalPython_5_5, + eval_python, + EvalPython_List_1, + EvalPython_1_List, + EvalPython_List_List, +) + + +@pytest.fixture +def eval_context(): + """Fixture that sets up execution context with eval enabled""" + config = default_configuration() + config.enable_eval = True + with context_configuration(config): + yield + + +def test_eval_python_basic_return(eval_context): + """Test basic return statement with single value""" + node = EvalPython_5_5() + result = node.exec_py(pycode="return 42", value0=0, value1=1, value2=2, value3=3, value4=4) + assert result == (42, None, None, None, None) + + +def test_eval_python_multiple_returns(eval_context): + """Test return statement with tuple of values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, None, None) + + +def test_eval_python_all_five_returns(eval_context): + """Test return statement with all five values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 'a', 'b', 'c', 'd', 'e'", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ('a', 'b', 'c', 'd', 'e') + + +def test_eval_python_excess_returns(eval_context): + """Test that excess return values are truncated to 5""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5, 6, 7", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_use_value_args(eval_context): + """Test that value arguments are accessible in pycode""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return value0 + value1 + value2", + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, None, None, None, None) + + +def test_eval_python_all_value_args(eval_context): + """Test all value arguments are accessible""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return value0, value1, value2, value3, value4", + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_computation(eval_context): + """Test computation with value arguments""" + node = EvalPython_5_5() + code = """ +x = value0 * 2 +y = value1 * 3 +z = x + y +return z +""" + result = node.exec_py( + pycode=code, + value0=5, value1=10, value2=0, value3=0, value4=0 + ) + assert result == (40, None, None, None, None) + + +def test_eval_python_multiline(eval_context): + """Test multiline code with conditionals""" + node = EvalPython_5_5() + code = """ +if value0 > 10: + result = "large" +else: + result = "small" +return result, value0 +""" + result = node.exec_py( + pycode=code, + value0=15, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", 15, None, None, None) + + +def test_eval_python_list_comprehension(eval_context): + """Test list comprehension and iteration""" + node = EvalPython_5_5() + code = """ +numbers = [value0, value1, value2] +doubled = [x * 2 for x in numbers] +return sum(doubled) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=2, value2=3, value3=0, value4=0 + ) + assert result == (12, None, None, None, None) + + +def test_eval_python_string_operations(eval_context): + """Test string operations""" + node = EvalPython_5_5() + code = """ +s1 = str(value0) +s2 = str(value1) +return s1 + s2, len(s1 + s2) +""" + result = node.exec_py( + pycode=code, + value0=123, value1=456, value2=0, value3=0, value4=0 + ) + assert result == ("123456", 6, None, None, None) + + +def test_eval_python_type_mixing(eval_context): + """Test mixing different types""" + node = EvalPython_5_5() + code = """ +return value0, str(value1), float(value2), bool(value3) +""" + result = node.exec_py( + pycode=code, + value0=42, value1=100, value2=3, value3=1, value4=0 + ) + assert result == (42, "100", 3.0, True, None) + + +def test_eval_python_logger_available(eval_context): + """Test that logger is available in eval context""" + node = EvalPython_5_5() + code = """ +logger.info("test log") +return "success" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("success", None, None, None, None) + + +def test_eval_python_print_available(eval_context): + """Test that print function is available""" + node = EvalPython_5_5() + code = """ +print("Hello World!") +return "printed" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("printed", None, None, None, None) + +def test_eval_python_print_is_called(eval_context): + """Test that print function is called and receives correct arguments""" + node = EvalPython_5_5() + + # Track print calls + print_calls = [] + + code = """ +print("Hello", "World") +print("Line 2") +return "done" +""" + + # Mock exec to capture the globals dict and verify print is there + original_exec = exec + captured_globals = {} + + def mock_exec(code_str, globals_dict, *args, **kwargs): + # Capture the globals dict + captured_globals.update(globals_dict) + + # Wrap the print function to track calls + original_print = globals_dict.get('print') + if original_print: + def tracked_print(*args): + print_calls.append(args) + return original_print(*args) + globals_dict['print'] = tracked_print + + # Run the original exec + return original_exec(code_str, globals_dict, *args, **kwargs) + + with patch('builtins.exec', side_effect=mock_exec): + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print was in the globals + assert 'print' in captured_globals + + # Verify print was called twice with correct arguments + assert len(print_calls) == 2 + assert print_calls[0] == ("Hello", "World") + assert print_calls[1] == ("Line 2",) + + +def test_eval_python_print_sends_to_server(eval_context): + """Test that print sends messages to PromptServer via context""" + from comfy.execution_context import current_execution_context + + node = EvalPython_5_5() + ctx = current_execution_context() + + # Mock the server's send_progress_text method + original_send = ctx.server.send_progress_text if hasattr(ctx.server, 'send_progress_text') else None + mock_send = Mock() + ctx.server.send_progress_text = mock_send + + code = """ +print("Hello", "World") +print("Value:", value0) +return "done" +""" + + try: + result = node.exec_py( + pycode=code, + value0=42, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print messages were sent to server + assert mock_send.call_count == 2 + + # Verify the messages sent + calls = mock_send.call_args_list + assert calls[0][0][0] == "Hello World" + assert calls[0][0][1] == ctx.node_id + assert calls[1][0][0] == "Value: 42" + assert calls[1][0][1] == ctx.node_id + finally: + # Restore original + if original_send: + ctx.server.send_progress_text = original_send + + +def test_eval_python_config_disabled_raises(): + """Test that enable_eval=False raises an error""" + node = EvalPython_5_5() + config = default_configuration() + config.enable_eval = False + with context_configuration(config): + with pytest.raises(ValueError, match="Python eval is disabled"): + node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + +def test_eval_python_config_enabled_works(eval_context): + """Test that enable_eval=True allows execution""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (42, None, None, None, None) + + +def test_eval_python_default_code(eval_context): + """Test the default code example works""" + node = EvalPython_5_5() + # Get the default code from INPUT_TYPES + default_code = EvalPython_5_5.INPUT_TYPES()["required"]["pycode"][1]["default"] + + result = node.exec_py( + pycode=default_code, + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + # Default code prints and returns the values + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_function_definition(eval_context): + """Test defining and using functions""" + node = EvalPython_5_5() + code = """ +def multiply(a, b): + return a * b + +result = multiply(value0, value1) +return result +""" + result = node.exec_py( + pycode=code, + value0=7, value1=8, value2=0, value3=0, value4=0 + ) + assert result == (56, None, None, None, None) + + +def test_eval_python_nested_functions(eval_context): + """Test nested function definitions""" + node = EvalPython_5_5() + code = """ +def outer(x): + def inner(y): + return y * 2 + return inner(x) + 10 + +result = outer(value0) +return result +""" + result = node.exec_py( + pycode=code, + value0=5, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (20, None, None, None, None) + + +def test_eval_python_dict_operations(eval_context): + """Test dictionary creation and operations""" + node = EvalPython_5_5() + code = """ +data = { + 'a': value0, + 'b': value1, + 'c': value2 +} +return sum(data.values()), len(data) +""" + result = node.exec_py( + pycode=code, + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, 3, None, None, None) + + +def test_eval_python_list_operations(eval_context): + """Test list creation and operations""" + node = EvalPython_5_5() + code = """ +items = [value0, value1, value2, value3, value4] +filtered = [x for x in items if x > 5] +return len(filtered), sum(filtered) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=10, value2=3, value3=15, value4=2 + ) + assert result == (2, 25, None, None, None) + + +def test_eval_python_early_return(eval_context): + """Test early return in conditional""" + node = EvalPython_5_5() + code = """ +if value0 > 100: + return "large" +return "small" +""" + result = node.exec_py( + pycode=code, + value0=150, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", None, None, None, None) + + +def test_eval_python_loop_with_return(eval_context): + """Test loop with return statement""" + node = EvalPython_5_5() + code = """ +total = 0 +for i in range(value0): + total += i +return total +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (45, None, None, None, None) + + +def test_eval_python_exception_handling(eval_context): + """Test try/except blocks""" + node = EvalPython_5_5() + code = """ +try: + result = value0 / value1 +except ZeroDivisionError: + result = float('inf') +return result +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (float('inf'), None, None, None, None) + + +def test_eval_python_none_values(eval_context): + """Test handling None values in inputs""" + node = EvalPython_5_5() + code = """ +return value0, value1 is None, value2 is None +""" + result = node.exec_py( + pycode=code, + value0=42, value1=None, value2=None, value3=0, value4=0 + ) + assert result == (42, True, True, None, None) + + +def test_eval_python_input_types(): + """Test that INPUT_TYPES returns correct structure""" + input_types = EvalPython_5_5.INPUT_TYPES() + assert "required" in input_types + assert "optional" in input_types + assert "pycode" in input_types["required"] + assert input_types["required"]["pycode"][0] == "PYCODE" + + # Check optional inputs + for i in range(5): + assert f"value{i}" in input_types["optional"] + + +def test_eval_python_metadata(): + """Test node metadata""" + assert EvalPython_5_5.FUNCTION == "exec_py" + assert EvalPython_5_5.CATEGORY == "eval" + assert len(EvalPython_5_5.RETURN_TYPES) == 5 + assert len(EvalPython_5_5.RETURN_NAMES) == 5 + assert all(name.startswith("item") for name in EvalPython_5_5.RETURN_NAMES) + + +def test_eval_python_factory_custom_inputs_outputs(eval_context): + """Test creating nodes with custom input/output counts""" + # Create a node with 3 inputs and 2 outputs + CustomNode = eval_python(inputs=3, outputs=2) + + node = CustomNode() + + # Verify INPUT_TYPES has correct number of inputs + input_types = CustomNode.INPUT_TYPES() + assert len(input_types["optional"]) == 3 + assert "value0" in input_types["optional"] + assert "value1" in input_types["optional"] + assert "value2" in input_types["optional"] + assert "value3" not in input_types["optional"] + + # Verify RETURN_TYPES has correct number of outputs + assert len(CustomNode.RETURN_TYPES) == 2 + assert len(CustomNode.RETURN_NAMES) == 2 + + # Test execution + result = node.exec_py( + pycode="return value0 + value1 + value2, value0 * 2", + value0=1, value1=2, value2=3 + ) + assert result == (6, 2) + + +def test_eval_python_factory_custom_name(eval_context): + """Test creating nodes with custom names""" + CustomNode = eval_python(inputs=2, outputs=2, name="MyCustomEval") + + assert CustomNode.__name__ == "MyCustomEval" + assert CustomNode.__qualname__ == "MyCustomEval" + + +def test_eval_python_factory_default_name(eval_context): + """Test that default name follows pattern""" + CustomNode = eval_python(inputs=3, outputs=4) + + assert CustomNode.__name__ == "EvalPython_3_4" + assert CustomNode.__qualname__ == "EvalPython_3_4" + + +def test_eval_python_factory_single_output(eval_context): + """Test node with single output""" + SingleOutputNode = eval_python(inputs=2, outputs=1) + + node = SingleOutputNode() + result = node.exec_py( + pycode="return value0 + value1", + value0=10, value1=20 + ) + assert result == (30,) + + +def test_eval_python_factory_many_outputs(eval_context): + """Test node with many outputs""" + ManyOutputNode = eval_python(inputs=1, outputs=10) + + node = ManyOutputNode() + result = node.exec_py( + pycode="return tuple(range(10))", + value0=0 + ) + assert result == (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + + +def test_eval_python_factory_fewer_returns_than_outputs(eval_context): + """Test that fewer returns are padded with None""" + Node = eval_python(inputs=2, outputs=5) + + node = Node() + result = node.exec_py( + pycode="return value0, value1", + value0=1, value1=2 + ) + assert result == (1, 2, None, None, None) + + +def test_eval_python_factory_more_returns_than_outputs(eval_context): + """Test that excess returns are truncated""" + Node = eval_python(inputs=2, outputs=3) + + node = Node() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5", + value0=0, value1=0 + ) + assert result == (1, 2, 3) + + +def test_eval_python_list_1_input_is_list(eval_context): + """Test EvalPython_List_1 with list input""" + node = EvalPython_List_1() + + # Verify INPUT_IS_LIST is set + assert EvalPython_List_1.INPUT_IS_LIST is True + assert EvalPython_List_1.OUTPUT_IS_LIST is None + + # Test that value0 receives a list + result = node.exec_py( + pycode="return sum(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (15,) + + +def test_eval_python_list_1_iterate_list(eval_context): + """Test EvalPython_List_1 iterating over list input""" + node = EvalPython_List_1() + + result = node.exec_py( + pycode="return [x * 2 for x in value0]", + value0=[1, 2, 3] + ) + assert result == ([2, 4, 6],) + + +def test_eval_python_1_list_output_is_list(eval_context): + """Test EvalPython_1_List with list output""" + node = EvalPython_1_List() + + # Verify OUTPUT_IS_LIST is set + assert EvalPython_1_List.INPUT_IS_LIST is False + assert EvalPython_1_List.OUTPUT_IS_LIST == (True,) + + # Test that returns a list + result = node.exec_py( + pycode="return list(range(value0))", + value0=5 + ) + assert result == ([0, 1, 2, 3, 4],) + + +def test_eval_python_1_list_multiple_items(eval_context): + """Test EvalPython_1_List returning multiple items in list""" + node = EvalPython_1_List() + + result = node.exec_py( + pycode="return ['a', 'b', 'c']", + value0=0 + ) + assert result == (['a', 'b', 'c'],) + + +def test_eval_python_list_list_both(eval_context): + """Test EvalPython_List_List with both list input and output""" + node = EvalPython_List_List() + + # Verify both are set + assert EvalPython_List_List.INPUT_IS_LIST is True + assert EvalPython_List_List.OUTPUT_IS_LIST == (True,) + + # Test processing list input and returning list output + result = node.exec_py( + pycode="return [x ** 2 for x in value0]", + value0=[1, 2, 3, 4] + ) + assert result == ([1, 4, 9, 16],) + + +def test_eval_python_list_list_filter(eval_context): + """Test EvalPython_List_List filtering a list""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [x for x in value0 if x > 5]", + value0=[1, 3, 5, 7, 9, 11] + ) + assert result == ([7, 9, 11],) + + +def test_eval_python_list_list_transform(eval_context): + """Test EvalPython_List_List transforming list elements""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [str(x).upper() for x in value0]", + value0=['hello', 'world', 'python'] + ) + assert result == (['HELLO', 'WORLD', 'PYTHON'],) + + +def test_eval_python_factory_with_list_flags(eval_context): + """Test factory function with custom list flags""" + # Create node with input as list but output scalar + ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None) + + assert ListInputNode.INPUT_IS_LIST is True + assert ListInputNode.OUTPUT_IS_LIST is None + + node = ListInputNode() + result = node.exec_py( + pycode="return len(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (5,) + + +def test_eval_python_factory_scalar_output_list(eval_context): + """Test factory function with scalar input and list output""" + ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,)) + + assert ScalarToListNode.INPUT_IS_LIST is False + assert ScalarToListNode.OUTPUT_IS_LIST == (True,) + + node = ScalarToListNode() + result = node.exec_py( + pycode="return [value0] * 3", + value0='x' + ) + assert result == (['x', 'x', 'x'],) + + +def test_eval_python_list_empty_list(eval_context): + """Test list nodes with empty lists""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return []", + value0=[] + ) + assert result == ([],) + + +def test_eval_python_backward_compatibility(): + """Test that EvalPython is an alias for EvalPython_5_5""" + assert EvalPython is EvalPython_5_5