mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fix issue finding approx vae taesdxl when used in a workflow
This commit is contained in:
parent
6bba743d62
commit
e62df3a881
@ -37,7 +37,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
||||
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
|
||||
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ValidationView
|
||||
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ComboOptions
|
||||
from ..component_model.files import canonicalize_path
|
||||
from ..component_model.module_property import create_module_properties
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
|
||||
@ -1038,10 +1038,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if isinstance(input_type, ComboOptions) or hasattr(input_type, "view_for_validation"):
|
||||
input_type = input_type.view_for_validation()
|
||||
|
||||
if isinstance(input_type, list):
|
||||
combo_options = input_type
|
||||
if isinstance(combo_options, ValidationView):
|
||||
combo_options = combo_options.view_for_validation()
|
||||
if isinstance(val, str) and "\\" in val:
|
||||
# try to normalize paths for comparison purposes
|
||||
val = canonicalize_path(val)
|
||||
|
||||
@ -291,6 +291,6 @@ ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None,
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ValidationView(Protocol):
|
||||
def view_for_validation(self) -> Iterable[str]:
|
||||
class ComboOptions(Protocol):
|
||||
def view_for_validation(self) -> list[str]:
|
||||
...
|
||||
|
||||
@ -523,10 +523,21 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
], folder_name="controlnet")
|
||||
|
||||
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors", show_in_ui=False),
|
||||
HuggingFile("madebyollin/taesd", "taesd_encoder.safetensors", show_in_ui=False),
|
||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors", show_in_ui=False),
|
||||
HuggingFile("madebyollin/taesdxl", "taesdxl_encoder.safetensors", show_in_ui=False),
|
||||
# todo: these are both the encoder and decoder, so it is not clear what should be done here
|
||||
# HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors", show_in_ui=False),
|
||||
# HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_decoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_encoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_encoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_decoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_encoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_decoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_encoder.pth", show_in_ui=False),
|
||||
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_decoder.pth", show_in_ui=False),
|
||||
], folder_name="vae_approx")
|
||||
|
||||
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Optional, List, Sequence, Union, Iterable
|
||||
from can_ada import parse, URL # pylint: disable=no-name-in-module
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
from .component_model.executor_types import ValidationView
|
||||
from .component_model.executor_types import ComboOptions
|
||||
from .component_model.files import canonicalize_path
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from .component_model.files import canonicalize_path
|
||||
class UrlFile:
|
||||
_url: str
|
||||
_save_with_filename: Optional[str] = None
|
||||
show_in_ui: Optional[bool] = True
|
||||
|
||||
def __str__(self):
|
||||
return self.save_with_filename
|
||||
@ -58,6 +59,7 @@ class CivitFile:
|
||||
model_version_id: int
|
||||
filename: str
|
||||
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
|
||||
show_in_ui: Optional[bool] = True
|
||||
|
||||
def __str__(self):
|
||||
return self.filename
|
||||
@ -95,7 +97,7 @@ class HuggingFile:
|
||||
return self.save_with_filename or split(self.filename)[-1]
|
||||
|
||||
|
||||
class DownloadableFileList(ValidationView, list[str]):
|
||||
class DownloadableFileList(ComboOptions, list[str]):
|
||||
"""
|
||||
A list of downloadable files that can be validated differently than it will be serialized to JSON
|
||||
"""
|
||||
@ -116,8 +118,8 @@ class DownloadableFileList(ValidationView, list[str]):
|
||||
|
||||
self.extend(sorted(list(map(canonicalize_path, ui_view))))
|
||||
|
||||
def view_for_validation(self) -> Iterable[str]:
|
||||
return self._validation_view
|
||||
def view_for_validation(self) -> list[str]:
|
||||
return sorted(list(frozenset(self._validation_view) | frozenset(self)))
|
||||
|
||||
|
||||
class CivitStats(TypedDict):
|
||||
|
||||
@ -772,7 +772,7 @@ class VAELoader:
|
||||
return vaes
|
||||
|
||||
@staticmethod
|
||||
def load_taesd(name):
|
||||
def load_taesd(name: str):
|
||||
sd_ = {}
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
|
||||
|
||||
@ -2,12 +2,13 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Optional, Type, Literal
|
||||
from typing import Optional, Type, Literal, Union, NamedTuple
|
||||
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
||||
DependencyExecutionErrorMessage
|
||||
DependencyExecutionErrorMessage, ComboOptions
|
||||
from comfy.nodes_context import get_nodes
|
||||
from .graph_utils import is_link
|
||||
|
||||
@ -59,11 +60,21 @@ class DynamicPrompt:
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
class InputInfoTuple(NamedTuple):
|
||||
# the input types can be a name or a combobox
|
||||
input_type: str | list[int] | ComboOptions
|
||||
input_category: Literal["required", "optional", "hidden"]
|
||||
extra_info: InputTypeOptions
|
||||
|
||||
|
||||
GetInputInfoResponse = InputInfoTuple | tuple[None, None, None]
|
||||
|
||||
|
||||
def get_input_info(
|
||||
class_def: Type[ComfyNodeABC],
|
||||
class_def: Union[Type[ComfyNodeABC], CustomNode],
|
||||
input_name: str,
|
||||
valid_inputs: InputTypeDict | None = None
|
||||
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
|
||||
valid_inputs: Union[InputTypeDict, InputTypes] = None
|
||||
) -> GetInputInfoResponse:
|
||||
"""Get the input type, category, and extra info for a given input name.
|
||||
|
||||
Arguments:
|
||||
|
||||
117
tests/unit/test_approx_vae.py
Normal file
117
tests/unit/test_approx_vae.py
Normal file
@ -0,0 +1,117 @@
|
||||
import pathlib
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from comfy.cli_args_types import Configuration
|
||||
from comfy.client.embedded_comfy_client import Comfy
|
||||
|
||||
|
||||
def build_workflow(vae_encoder_option="", vae_decoder_option=""):
|
||||
return {
|
||||
"1": {
|
||||
"inputs": {
|
||||
"vae_name": vae_encoder_option
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"2": {
|
||||
"inputs": {
|
||||
"pixels": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"1",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEEncode",
|
||||
"_meta": {
|
||||
"title": "VAE Encode"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1,
|
||||
"color": 0
|
||||
},
|
||||
"class_type": "EmptyImage",
|
||||
"_meta": {
|
||||
"title": "EmptyImage"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"2",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"vae_name": vae_decoder_option
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "Load VAE"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"filename_prefix": "test",
|
||||
"images": [
|
||||
"4",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def test_approx_vae_found(tmp_path_factory):
|
||||
workflow = build_workflow("taesdxl", "taesdxl")
|
||||
configuration = Configuration()
|
||||
base_dir = tmp_path_factory.mktemp("test_approx_vae_found")
|
||||
configuration.base_directory = str(base_dir)
|
||||
async with Comfy(configuration) as comfy:
|
||||
from comfy.nodes.base_nodes import VAELoader
|
||||
|
||||
assert "taesdxl" not in VAELoader.vae_list(), "should not be downloadable"
|
||||
with pytest.raises(ValueError):
|
||||
# should not attempt to download
|
||||
await comfy.queue_prompt_api(workflow)
|
||||
|
||||
# download both vaes
|
||||
taesdxl_decoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
|
||||
taesdxl_encoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_encoder.safetensors")
|
||||
assert taesdxl_decoder_path is not None
|
||||
assert taesdxl_encoder_path is not None
|
||||
|
||||
vae_approx_dir = base_dir / "models" / "vae_approx"
|
||||
vae_approx_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shutil.copy(taesdxl_encoder_path, vae_approx_dir / "taesdxl_encoder.safetensors")
|
||||
shutil.copy(taesdxl_decoder_path, vae_approx_dir / "taesdxl_decoder.safetensors")
|
||||
|
||||
# now should work
|
||||
await comfy.queue_prompt_api(workflow)
|
||||
Loading…
Reference in New Issue
Block a user