mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
Fix issue finding approx vae taesdxl when used in a workflow
This commit is contained in:
parent
6bba743d62
commit
e62df3a881
@ -37,7 +37,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
|||||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||||
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
||||||
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
|
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
|
||||||
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ValidationView
|
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage, ComboOptions
|
||||||
from ..component_model.files import canonicalize_path
|
from ..component_model.files import canonicalize_path
|
||||||
from ..component_model.module_property import create_module_properties
|
from ..component_model.module_property import create_module_properties
|
||||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
|
||||||
@ -1038,10 +1038,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(input_type, ComboOptions) or hasattr(input_type, "view_for_validation"):
|
||||||
|
input_type = input_type.view_for_validation()
|
||||||
|
|
||||||
if isinstance(input_type, list):
|
if isinstance(input_type, list):
|
||||||
combo_options = input_type
|
combo_options = input_type
|
||||||
if isinstance(combo_options, ValidationView):
|
|
||||||
combo_options = combo_options.view_for_validation()
|
|
||||||
if isinstance(val, str) and "\\" in val:
|
if isinstance(val, str) and "\\" in val:
|
||||||
# try to normalize paths for comparison purposes
|
# try to normalize paths for comparison purposes
|
||||||
val = canonicalize_path(val)
|
val = canonicalize_path(val)
|
||||||
|
|||||||
@ -291,6 +291,6 @@ ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None,
|
|||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class ValidationView(Protocol):
|
class ComboOptions(Protocol):
|
||||||
def view_for_validation(self) -> Iterable[str]:
|
def view_for_validation(self) -> list[str]:
|
||||||
...
|
...
|
||||||
|
|||||||
@ -523,10 +523,21 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
|||||||
], folder_name="controlnet")
|
], folder_name="controlnet")
|
||||||
|
|
||||||
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors", show_in_ui=False),
|
||||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
HuggingFile("madebyollin/taesd", "taesd_encoder.safetensors", show_in_ui=False),
|
||||||
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
|
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors", show_in_ui=False),
|
||||||
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
|
HuggingFile("madebyollin/taesdxl", "taesdxl_encoder.safetensors", show_in_ui=False),
|
||||||
|
# todo: these are both the encoder and decoder, so it is not clear what should be done here
|
||||||
|
# HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors", show_in_ui=False),
|
||||||
|
# HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_decoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd_encoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_encoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesdxl_decoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_encoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taesd3_decoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_encoder.pth", show_in_ui=False),
|
||||||
|
UrlFile("https://raw.githubusercontent.com/madebyollin/taesd/main/taef1_decoder.pth", show_in_ui=False),
|
||||||
], folder_name="vae_approx")
|
], folder_name="vae_approx")
|
||||||
|
|
||||||
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from typing import Optional, List, Sequence, Union, Iterable
|
|||||||
from can_ada import parse, URL # pylint: disable=no-name-in-module
|
from can_ada import parse, URL # pylint: disable=no-name-in-module
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
from .component_model.executor_types import ValidationView
|
from .component_model.executor_types import ComboOptions
|
||||||
from .component_model.files import canonicalize_path
|
from .component_model.files import canonicalize_path
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +18,7 @@ from .component_model.files import canonicalize_path
|
|||||||
class UrlFile:
|
class UrlFile:
|
||||||
_url: str
|
_url: str
|
||||||
_save_with_filename: Optional[str] = None
|
_save_with_filename: Optional[str] = None
|
||||||
|
show_in_ui: Optional[bool] = True
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.save_with_filename
|
return self.save_with_filename
|
||||||
@ -58,6 +59,7 @@ class CivitFile:
|
|||||||
model_version_id: int
|
model_version_id: int
|
||||||
filename: str
|
filename: str
|
||||||
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
|
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
|
||||||
|
show_in_ui: Optional[bool] = True
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.filename
|
return self.filename
|
||||||
@ -95,7 +97,7 @@ class HuggingFile:
|
|||||||
return self.save_with_filename or split(self.filename)[-1]
|
return self.save_with_filename or split(self.filename)[-1]
|
||||||
|
|
||||||
|
|
||||||
class DownloadableFileList(ValidationView, list[str]):
|
class DownloadableFileList(ComboOptions, list[str]):
|
||||||
"""
|
"""
|
||||||
A list of downloadable files that can be validated differently than it will be serialized to JSON
|
A list of downloadable files that can be validated differently than it will be serialized to JSON
|
||||||
"""
|
"""
|
||||||
@ -116,8 +118,8 @@ class DownloadableFileList(ValidationView, list[str]):
|
|||||||
|
|
||||||
self.extend(sorted(list(map(canonicalize_path, ui_view))))
|
self.extend(sorted(list(map(canonicalize_path, ui_view))))
|
||||||
|
|
||||||
def view_for_validation(self) -> Iterable[str]:
|
def view_for_validation(self) -> list[str]:
|
||||||
return self._validation_view
|
return sorted(list(frozenset(self._validation_view) | frozenset(self)))
|
||||||
|
|
||||||
|
|
||||||
class CivitStats(TypedDict):
|
class CivitStats(TypedDict):
|
||||||
|
|||||||
@ -772,7 +772,7 @@ class VAELoader:
|
|||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_taesd(name):
|
def load_taesd(name: str):
|
||||||
sd_ = {}
|
sd_ = {}
|
||||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||||
|
|
||||||
|
|||||||
@ -2,12 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Optional, Type, Literal
|
from typing import Optional, Type, Literal, Union, NamedTuple
|
||||||
|
|
||||||
|
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||||
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||||
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
||||||
DependencyExecutionErrorMessage
|
DependencyExecutionErrorMessage, ComboOptions
|
||||||
from comfy.nodes_context import get_nodes
|
from comfy.nodes_context import get_nodes
|
||||||
from .graph_utils import is_link
|
from .graph_utils import is_link
|
||||||
|
|
||||||
@ -59,11 +60,21 @@ class DynamicPrompt:
|
|||||||
return self.original_prompt
|
return self.original_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class InputInfoTuple(NamedTuple):
|
||||||
|
# the input types can be a name or a combobox
|
||||||
|
input_type: str | list[int] | ComboOptions
|
||||||
|
input_category: Literal["required", "optional", "hidden"]
|
||||||
|
extra_info: InputTypeOptions
|
||||||
|
|
||||||
|
|
||||||
|
GetInputInfoResponse = InputInfoTuple | tuple[None, None, None]
|
||||||
|
|
||||||
|
|
||||||
def get_input_info(
|
def get_input_info(
|
||||||
class_def: Type[ComfyNodeABC],
|
class_def: Union[Type[ComfyNodeABC], CustomNode],
|
||||||
input_name: str,
|
input_name: str,
|
||||||
valid_inputs: InputTypeDict | None = None
|
valid_inputs: Union[InputTypeDict, InputTypes] = None
|
||||||
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
|
) -> GetInputInfoResponse:
|
||||||
"""Get the input type, category, and extra info for a given input name.
|
"""Get the input type, category, and extra info for a given input name.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|||||||
117
tests/unit/test_approx_vae.py
Normal file
117
tests/unit/test_approx_vae.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import pathlib
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from comfy.cli_args_types import Configuration
|
||||||
|
from comfy.client.embedded_comfy_client import Comfy
|
||||||
|
|
||||||
|
|
||||||
|
def build_workflow(vae_encoder_option="", vae_decoder_option=""):
|
||||||
|
return {
|
||||||
|
"1": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": vae_encoder_option
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"inputs": {
|
||||||
|
"pixels": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"1",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Encode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"3": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
"batch_size": 1,
|
||||||
|
"color": 0
|
||||||
|
},
|
||||||
|
"class_type": "EmptyImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "EmptyImage"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"2",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE Decode"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": vae_decoder_option
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "test",
|
||||||
|
"images": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Save Image"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_approx_vae_found(tmp_path_factory):
|
||||||
|
workflow = build_workflow("taesdxl", "taesdxl")
|
||||||
|
configuration = Configuration()
|
||||||
|
base_dir = tmp_path_factory.mktemp("test_approx_vae_found")
|
||||||
|
configuration.base_directory = str(base_dir)
|
||||||
|
async with Comfy(configuration) as comfy:
|
||||||
|
from comfy.nodes.base_nodes import VAELoader
|
||||||
|
|
||||||
|
assert "taesdxl" not in VAELoader.vae_list(), "should not be downloadable"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# should not attempt to download
|
||||||
|
await comfy.queue_prompt_api(workflow)
|
||||||
|
|
||||||
|
# download both vaes
|
||||||
|
taesdxl_decoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
|
||||||
|
taesdxl_encoder_path = hf_hub_download("madebyollin/taesdxl", "taesdxl_encoder.safetensors")
|
||||||
|
assert taesdxl_decoder_path is not None
|
||||||
|
assert taesdxl_encoder_path is not None
|
||||||
|
|
||||||
|
vae_approx_dir = base_dir / "models" / "vae_approx"
|
||||||
|
vae_approx_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
shutil.copy(taesdxl_encoder_path, vae_approx_dir / "taesdxl_encoder.safetensors")
|
||||||
|
shutil.copy(taesdxl_decoder_path, vae_approx_dir / "taesdxl_decoder.safetensors")
|
||||||
|
|
||||||
|
# now should work
|
||||||
|
await comfy.queue_prompt_api(workflow)
|
||||||
Loading…
Reference in New Issue
Block a user