From e62df3a8811d8c652a195d4669f4fb27f6c9a9ba Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Fri, 26 Sep 2025 12:28:48 -0700 Subject: [PATCH] Fix issue finding approx vae taesdxl when used in a workflow --- comfy/cmd/execution.py | 7 +- comfy/component_model/executor_types.py | 4 +- comfy/model_downloader.py | 19 +++- comfy/model_downloader_types.py | 10 +- comfy/nodes/base_nodes.py | 2 +- comfy_execution/graph.py | 21 ++++- tests/unit/test_approx_vae.py | 117 ++++++++++++++++++++++++ 7 files changed, 161 insertions(+), 19 deletions(-) create mode 100644 tests/unit/test_approx_vae.py diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 7fdbde088..a7c67e723 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -37,7 +37,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \ - HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ValidationView + HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ComboOptions from ..component_model.files import canonicalize_path from ..component_model.module_property import create_module_properties from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \ @@ -1038,10 +1038,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing errors.append(error) continue + if isinstance(input_type, ComboOptions) or hasattr(input_type, "view_for_validation"): + input_type = input_type.view_for_validation() + if isinstance(input_type, list): combo_options = input_type - if isinstance(combo_options, ValidationView): - combo_options = combo_options.view_for_validation() if isinstance(val, str) and "\\" in val: # try to normalize paths for comparison purposes val = canonicalize_path(val) diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index d168f2023..e5fb8e59c 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -291,6 +291,6 @@ ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None, @runtime_checkable -class ValidationView(Protocol): - def view_for_validation(self) -> Iterable[str]: +class ComboOptions(Protocol): + def view_for_validation(self) -> list[str]: ... diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index ed63bde9f..1fd929f6a 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -523,10 +523,21 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ ], folder_name="controlnet") KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([ - HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), - HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"), - HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"), - HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"), + HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors", show_in_ui=False), + HuggingFile("madebyollin/taesd", "taesd_encoder.safetensors", show_in_ui=False), + HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors", show_in_ui=False), + HuggingFile("madebyollin/taesdxl", "taesdxl_encoder.safetensors", show_in_ui=False), + # todo: these are both the encoder and decoder, so it is not clear what should be done here + # HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors", show_in_ui=False), + # HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_decoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_encoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_encoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_decoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_encoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_decoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_encoder.pth", show_in_ui=False), + UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_decoder.pth", show_in_ui=False), ], folder_name="vae_approx") KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index e80438fd0..749156c85 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -10,7 +10,7 @@ from typing import Optional, List, Sequence, Union, Iterable from can_ada import parse, URL # pylint: disable=no-name-in-module from typing_extensions import TypedDict, NotRequired -from .component_model.executor_types import ValidationView +from .component_model.executor_types import ComboOptions from .component_model.files import canonicalize_path @@ -18,6 +18,7 @@ from .component_model.files import canonicalize_path class UrlFile: _url: str _save_with_filename: Optional[str] = None + show_in_ui: Optional[bool] = True def __str__(self): return self.save_with_filename @@ -58,6 +59,7 @@ class CivitFile: model_version_id: int filename: str trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple) + show_in_ui: Optional[bool] = True def __str__(self): return self.filename @@ -95,7 +97,7 @@ class HuggingFile: return self.save_with_filename or split(self.filename)[-1] -class DownloadableFileList(ValidationView, list[str]): +class DownloadableFileList(ComboOptions, list[str]): """ A list of downloadable files that can be validated differently than it will be serialized to JSON """ @@ -116,8 +118,8 @@ class DownloadableFileList(ValidationView, list[str]): self.extend(sorted(list(map(canonicalize_path, ui_view)))) - def view_for_validation(self) -> Iterable[str]: - return self._validation_view + def view_for_validation(self) -> list[str]: + return sorted(list(frozenset(self._validation_view) | frozenset(self))) class CivitStats(TypedDict): diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index a4391da5c..8884e10b6 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -772,7 +772,7 @@ class VAELoader: return vaes @staticmethod - def load_taesd(name): + def load_taesd(name: str): sd_ = {} approx_vaes = folder_paths.get_filename_list("vae_approx") diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index f576bb92f..b3dfcf49b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -2,12 +2,13 @@ from __future__ import annotations import asyncio import inspect -from typing import Optional, Type, Literal +from typing import Optional, Type, Literal, Union, NamedTuple +from comfy.nodes.package_typing import CustomNode, InputTypes from comfy_execution.graph_utils import is_link, ExecutionBlocker from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ - DependencyExecutionErrorMessage + DependencyExecutionErrorMessage, ComboOptions from comfy.nodes_context import get_nodes from .graph_utils import is_link @@ -59,11 +60,21 @@ class DynamicPrompt: return self.original_prompt +class InputInfoTuple(NamedTuple): + # the input types can be a name or a combobox + input_type: str | list[int] | ComboOptions + input_category: Literal["required", "optional", "hidden"] + extra_info: InputTypeOptions + + +GetInputInfoResponse = InputInfoTuple | tuple[None, None, None] + + def get_input_info( - class_def: Type[ComfyNodeABC], + class_def: Union[Type[ComfyNodeABC], CustomNode], input_name: str, - valid_inputs: InputTypeDict | None = None -) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + valid_inputs: Union[InputTypeDict, InputTypes] = None +) -> GetInputInfoResponse: """Get the input type, category, and extra info for a given input name. Arguments: diff --git a/tests/unit/test_approx_vae.py b/tests/unit/test_approx_vae.py new file mode 100644 index 000000000..102336624 --- /dev/null +++ b/tests/unit/test_approx_vae.py @@ -0,0 +1,117 @@ +import pathlib +import shutil + +import pytest +from huggingface_hub import hf_hub_download + +from comfy.cli_args_types import Configuration +from comfy.client.embedded_comfy_client import Comfy + + +def build_workflow(vae_encoder_option="", vae_decoder_option=""): + return { + "1": { + "inputs": { + "vae_name": vae_encoder_option + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "2": { + "inputs": { + "pixels": [ + "3", + 0 + ], + "vae": [ + "1", + 0 + ] + }, + "class_type": "VAEEncode", + "_meta": { + "title": "VAE Encode" + } + }, + "3": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1, + "color": 0 + }, + "class_type": "EmptyImage", + "_meta": { + "title": "EmptyImage" + } + }, + "4": { + "inputs": { + "samples": [ + "2", + 0 + ], + "vae": [ + "5", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "5": { + "inputs": { + "vae_name": vae_decoder_option + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "6": { + "inputs": { + "filename_prefix": "test", + "images": [ + "4", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } + } + + +async def test_approx_vae_found(tmp_path_factory): + workflow = build_workflow("taesdxl", "taesdxl") + configuration = Configuration() + base_dir = tmp_path_factory.mktemp("test_approx_vae_found") + configuration.base_directory = str(base_dir) + async with Comfy(configuration) as comfy: + from comfy.nodes.base_nodes import VAELoader + + assert "taesdxl" not in VAELoader.vae_list(), "should not be downloadable" + with pytest.raises(ValueError): + # should not attempt to download + await comfy.queue_prompt_api(workflow) + + # download both vaes + taesdxl_decoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_decoder.safetensors") + taesdxl_encoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_encoder.safetensors") + assert taesdxl_decoder_path is not None + assert taesdxl_encoder_path is not None + + vae_approx_dir = base_dir / "models" / "vae_approx" + vae_approx_dir.mkdir(parents=True, exist_ok=True) + + shutil.copy(taesdxl_encoder_path, vae_approx_dir / "taesdxl_encoder.safetensors") + shutil.copy(taesdxl_decoder_path, vae_approx_dir / "taesdxl_decoder.safetensors") + + # now should work + await comfy.queue_prompt_api(workflow)