ComfyUI/tests/inference/test_workflows.py
doctorpangloss 2bc95c1711 Test improvements and fixes
- move workflows to distinct json files
 - add the comfy-org workflows for testing
 - fix issues where workflows from Windows users were not compatible
   with backends running on Linux or macOS because of path separator
   differences. Because this codebase uses get_or_download wherever
   checkpoints, models, etc. are used, this is the only place where the
   comparison is handled gracefully for downloading. Validation code
   now converts backslashes to forward slashes, on the assumption that
   wherever such values are used and compared against a list, they are
   intended to be paths rather than literal symbols.
2024-08-05 15:55:46 -07:00

51 lines
1.9 KiB
Python

import importlib.resources
import json
from collections.abc import AsyncGenerator
from importlib.abc import Traversable

import pytest

from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile

from . import workflows
@pytest.fixture(scope="module", autouse=False)
async def client() -> AsyncGenerator[EmbeddedComfyClient, None]:
    """Yield one ``EmbeddedComfyClient`` shared by every test in this module.

    The client is used as an async context manager; its context is exited
    after the last test in the module has finished with it.

    ``@pytest.mark.asyncio`` is intentionally not applied here: marks applied
    to fixture functions never had any effect (pytest deprecates them since
    7.4). Running this coroutine fixture presumably relies on pytest-asyncio
    handling plain ``@pytest.fixture`` async generators (e.g. auto mode) --
    TODO confirm against the project's pytest configuration.
    """
    async with EmbeddedComfyClient() as embedded_client:
        yield embedded_client
def _prepare_for_workflows() -> dict[str, Traversable]:
    """Register extra known models and collect the bundled workflow JSON files.

    This runs when the test module is imported, because ``test_workflow``'s
    ``parametrize`` decorator calls it. The model registration below therefore
    happens before any workflow test executes.

    Returns:
        A mapping from file name to ``Traversable`` for every ``*.json`` file
        directly inside the ``workflows`` package. Entries are ordered by file
        name, because ``iterdir()`` makes no ordering guarantee. The order
        determines parametrize order and test IDs, so sorting keeps collection
        identical across runs and across parallel workers.
    """
    # Make this Civitai LoRA known to the model downloader; presumably one of
    # the bundled workflows references it -- TODO confirm which.
    add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
    json_files = (
        entry
        for entry in importlib.resources.files(workflows).iterdir()
        if entry.is_file() and entry.name.endswith(".json")
    )
    return {entry.name: entry for entry in sorted(json_files, key=lambda entry: entry.name)}
def _find_node_id(prompt: Prompt, class_type: str):
    """Return the key of the first node in *prompt* whose ``class_type`` matches, else ``None``."""
    return next((node_id for node_id in prompt if prompt[node_id].class_type == class_type), None)


@pytest.mark.asyncio
@pytest.mark.parametrize("workflow_name, workflow_file", _prepare_for_workflows().items())
async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu: bool, client: EmbeddedComfyClient):
    """Run one bundled workflow end to end and check that it produced output.

    The test is skipped when no GPU is available. It is also skipped for
    workflows whose name contains ``"audio"`` when ``torchaudio`` cannot be
    imported.

    After the prompt is queued, the first ``SaveImage`` node must report a
    non-``None`` ``abs_path`` for its first image. Only when there is no
    ``SaveImage`` node is the first ``SaveAudio`` node checked, and it must
    report a non-``None`` ``filename`` for its first audio entry.
    """
    if not has_gpu:
        pytest.skip("requires gpu")
    if "audio" in workflow_name:
        try:
            import torchaudio  # noqa: F401  (imported only to probe availability)
        except ImportError:  # also covers ModuleNotFoundError, its subclass
            pytest.skip("requires torchaudio")
    # JSON is UTF-8 by specification (RFC 8259). Without an explicit encoding,
    # read_text() uses the locale encoding (e.g. cp1252 on Windows) and can
    # fail on, or garble, non-ASCII text in a workflow.
    workflow = json.loads(workflow_file.read_text(encoding="utf-8"))
    prompt = Prompt.validate(workflow)
    # todo: add all the models we want to test a bit more elegantly
    outputs = await client.queue_prompt(prompt)

    save_image_node_id = _find_node_id(prompt, "SaveImage")
    if save_image_node_id is not None:
        assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
        return

    save_audio_node_id = _find_node_id(prompt, "SaveAudio")
    if save_audio_node_id is not None:
        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None