mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
- Move workflows to distinct JSON files. - Add the comfy-org workflows for testing. - Fix issues where workflows created by Windows users were incompatible with backends running on Linux or macOS because of path-separator differences. Because this codebase uses get_or_download wherever checkpoints, models, etc. are used, that is the only place where the comparison is handled gracefully for downloading. Validation code converts backslashes to forward slashes, on the assumption that every such value, wherever it is used or compared against a list, is intended to be a path rather than a literal symbol.
51 lines
1.9 KiB
Python
51 lines
1.9 KiB
Python
import importlib.resources
import json
from collections.abc import AsyncGenerator
from importlib.abc import Traversable

import pytest

from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile

from . import workflows
|
|
|
|
|
|
@pytest.fixture(scope="module", autouse=False)
async def client(tmp_path_factory) -> AsyncGenerator[EmbeddedComfyClient, None]:
    """Yield one ``EmbeddedComfyClient`` shared by every test in this module.

    The client is entered as an async context manager and is exited after the
    last test in the module has used it (``scope="module"``).

    Args:
        tmp_path_factory: Requested but currently unused; kept so the
            fixture's dependencies stay unchanged.

    Yields:
        The running embedded client.
    """
    # NOTE: no ``@pytest.mark.asyncio`` here -- marks applied to fixture
    # functions never had any effect, and pytest >= 8 warns about them.
    # Async fixture support presumably comes from the pytest-asyncio plugin
    # configuration (e.g. asyncio_mode=auto) -- confirm in the project config.
    async with EmbeddedComfyClient() as embedded_client:
        yield embedded_client
|
|
|
|
|
|
def _prepare_for_workflows() -> dict[str, Traversable]:
    """Register extra known models and collect the bundled workflow files.

    This runs at module import time, because its result feeds the
    ``parametrize`` decorator of ``test_workflow``. The model registration
    therefore happens before any workflow is queued.

    Returns:
        A mapping of file name to ``Traversable`` for every ``*.json`` file
        directly inside the ``workflows`` package. Entries are ordered by
        file name, so the parametrized test order is deterministic
        regardless of the filesystem's ``iterdir()`` order. Stable ordering
        matters, for example, for pytest-xdist, which requires identical
        collection across workers.
    """
    # Make the "epi_noiseoffset2.safetensors" LoRA known to the downloader.
    # The two integers are presumably Civitai model/version ids -- confirm.
    add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))

    json_entries = [
        entry
        for entry in importlib.resources.files(workflows).iterdir()
        if entry.is_file() and entry.name.endswith(".json")
    ]
    json_entries.sort(key=lambda entry: entry.name)
    return {entry.name: entry for entry in json_entries}
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("workflow_name, workflow_file", _prepare_for_workflows().items())
async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu: bool, client: EmbeddedComfyClient):
    """Run one bundled workflow end to end and check that it produced output.

    The test is skipped when no GPU is available. It is also skipped for
    workflows whose file name contains ``"audio"`` when ``torchaudio`` cannot
    be imported.

    If the workflow contains a ``SaveImage`` node, the test asserts that the
    first such node reported an ``abs_path`` for its first image. Otherwise,
    if it contains a ``SaveAudio`` node, the test asserts that the first such
    node reported a ``filename`` for its first audio output.

    Args:
        workflow_name: File name of the workflow JSON (parametrized).
        workflow_file: Resource handle for that JSON file (parametrized).
        has_gpu: Fixture reporting GPU availability (defined elsewhere).
        client: The module-scoped embedded client fixture.
    """
    if not has_gpu:
        pytest.skip("requires gpu")

    if "audio" in workflow_name:
        try:
            import torchaudio  # noqa: F401 -- imported only to probe availability
        except ImportError:  # also covers ModuleNotFoundError, a subclass
            pytest.skip("requires torchaudio")

    # Parse from bytes: json.loads auto-detects UTF-8/16/32 and tolerates a
    # UTF-8 BOM. read_text() without an encoding would use the locale default
    # (e.g. cp1252 on Windows) and break on non-ASCII workflow content.
    workflow = json.loads(workflow_file.read_bytes())

    prompt = Prompt.validate(workflow)
    # todo: register the models these workflows need (see _prepare_for_workflows) a bit more elegantly
    outputs = await client.queue_prompt(prompt)

    save_image_node_id = next((key for key in prompt if prompt[key].class_type == "SaveImage"), None)
    if save_image_node_id is not None:
        assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
        return

    save_audio_node_id = next((key for key in prompt if prompt[key].class_type == "SaveAudio"), None)
    if save_audio_node_id is not None:
        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
|