Fix issue finding approx vae taesdxl when used in a workflow

This commit is contained in:
Benjamin Berman 2025-09-26 12:28:48 -07:00
parent 6bba743d62
commit e62df3a881
7 changed files with 161 additions and 19 deletions

View File

@ -37,7 +37,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ValidationView
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ComboOptions
from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
@ -1038,10 +1038,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
errors.append(error)
continue
if isinstance(input_type, ComboOptions) or hasattr(input_type, "view_for_validation"):
input_type = input_type.view_for_validation()
if isinstance(input_type, list):
combo_options = input_type
if isinstance(combo_options, ValidationView):
combo_options = combo_options.view_for_validation()
if isinstance(val, str) and "\\" in val:
# try to normalize paths for comparison purposes
val = canonicalize_path(val)

View File

@ -291,6 +291,6 @@ ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None,
@runtime_checkable
class ValidationView(Protocol):
def view_for_validation(self) -> Iterable[str]:
class ComboOptions(Protocol):
def view_for_validation(self) -> list[str]:
...

View File

@ -523,10 +523,21 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
], folder_name="controlnet")
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors", show_in_ui=False),
HuggingFile("madebyollin/taesd", "taesd_encoder.safetensors", show_in_ui=False),
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors", show_in_ui=False),
HuggingFile("madebyollin/taesdxl", "taesdxl_encoder.safetensors", show_in_ui=False),
# todo: these are both the encoder and decoder, so it is not clear what should be done here
# HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors", show_in_ui=False),
# HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_decoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_encoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_encoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_decoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_encoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_decoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_encoder.pth", show_in_ui=False),
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_decoder.pth", show_in_ui=False),
], folder_name="vae_approx")
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -10,7 +10,7 @@ from typing import Optional, List, Sequence, Union, Iterable
from can_ada import parse, URL # pylint: disable=no-name-in-module
from typing_extensions import TypedDict, NotRequired
from .component_model.executor_types import ValidationView
from .component_model.executor_types import ComboOptions
from .component_model.files import canonicalize_path
@ -18,6 +18,7 @@ from .component_model.files import canonicalize_path
class UrlFile:
_url: str
_save_with_filename: Optional[str] = None
show_in_ui: Optional[bool] = True
def __str__(self):
return self.save_with_filename
@ -58,6 +59,7 @@ class CivitFile:
model_version_id: int
filename: str
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
show_in_ui: Optional[bool] = True
def __str__(self):
return self.filename
@ -95,7 +97,7 @@ class HuggingFile:
return self.save_with_filename or split(self.filename)[-1]
class DownloadableFileList(ValidationView, list[str]):
class DownloadableFileList(ComboOptions, list[str]):
"""
A list of downloadable files that can be validated differently than it will be serialized to JSON
"""
@ -116,8 +118,8 @@ class DownloadableFileList(ValidationView, list[str]):
self.extend(sorted(list(map(canonicalize_path, ui_view))))
def view_for_validation(self) -> Iterable[str]:
return self._validation_view
def view_for_validation(self) -> list[str]:
return sorted(list(frozenset(self._validation_view) | frozenset(self)))
class CivitStats(TypedDict):

View File

@ -772,7 +772,7 @@ class VAELoader:
return vaes
@staticmethod
def load_taesd(name):
def load_taesd(name: str):
sd_ = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")

View File

@ -2,12 +2,13 @@ from __future__ import annotations
import asyncio
import inspect
from typing import Optional, Type, Literal
from typing import Optional, Type, Literal, Union, NamedTuple
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
DependencyExecutionErrorMessage
DependencyExecutionErrorMessage, ComboOptions
from comfy.nodes_context import get_nodes
from .graph_utils import is_link
@ -59,11 +60,21 @@ class DynamicPrompt:
return self.original_prompt
class InputInfoTuple(NamedTuple):
# the input types can be a name or a combobox
input_type: str | list[int] | ComboOptions
input_category: Literal["required", "optional", "hidden"]
extra_info: InputTypeOptions
GetInputInfoResponse = InputInfoTuple | tuple[None, None, None]
def get_input_info(
class_def: Type[ComfyNodeABC],
class_def: Union[Type[ComfyNodeABC], CustomNode],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
valid_inputs: Union[InputTypeDict, InputTypes] = None
) -> GetInputInfoResponse:
"""Get the input type, category, and extra info for a given input name.
Arguments:

View File

@ -0,0 +1,117 @@
import pathlib
import shutil
import pytest
from huggingface_hub import hf_hub_download
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
def build_workflow(vae_encoder_option="", vae_decoder_option=""):
return {
"1": {
"inputs": {
"vae_name": vae_encoder_option
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"2": {
"inputs": {
"pixels": [
"3",
0
],
"vae": [
"1",
0
]
},
"class_type": "VAEEncode",
"_meta": {
"title": "VAE Encode"
}
},
"3": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1,
"color": 0
},
"class_type": "EmptyImage",
"_meta": {
"title": "EmptyImage"
}
},
"4": {
"inputs": {
"samples": [
"2",
0
],
"vae": [
"5",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"5": {
"inputs": {
"vae_name": vae_decoder_option
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"6": {
"inputs": {
"filename_prefix": "test",
"images": [
"4",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
async def test_approx_vae_found(tmp_path_factory):
workflow = build_workflow("taesdxl", "taesdxl")
configuration = Configuration()
base_dir = tmp_path_factory.mktemp("test_approx_vae_found")
configuration.base_directory = str(base_dir)
async with Comfy(configuration) as comfy:
from comfy.nodes.base_nodes import VAELoader
assert "taesdxl" not in VAELoader.vae_list(), "should not be downloadable"
with pytest.raises(ValueError):
# should not attempt to download
await comfy.queue_prompt_api(workflow)
# download both vaes
taesdxl_decoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
taesdxl_encoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_encoder.safetensors")
assert taesdxl_decoder_path is not None
assert taesdxl_encoder_path is not None
vae_approx_dir = base_dir / "models" / "vae_approx"
vae_approx_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(taesdxl_encoder_path, vae_approx_dir / "taesdxl_encoder.safetensors")
shutil.copy(taesdxl_decoder_path, vae_approx_dir / "taesdxl_decoder.safetensors")
# now should work
await comfy.queue_prompt_api(workflow)