From 2cad5ec0d6fdbff0b104b78c8e19637d7ca99b7d Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Fri, 28 Jun 2024 17:06:29 -0700
Subject: [PATCH] LoRA test

---
 .github/workflows/test.yml            |   4 +
 comfy/client/embedded_comfy_client.py |   6 +
 comfy/model_downloader.py             |   2 +-
 tests/workflows/__init__.py           |   0
 tests/workflows/lora.py               | 154 ++++++++++++++++++++++++++
 5 files changed, 165 insertions(+), 1 deletion(-)
 create mode 100644 tests/workflows/__init__.py
 create mode 100644 tests/workflows/lora.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d4a978bd5..26d7630d8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -41,6 +41,10 @@ jobs:
         run: |
           export HSA_OVERRIDE_GFX_VERSION=11.0.0
           pytest -v tests/unit
+      - name: Run lora workflow
+        run: |
+          export HSA_OVERRIDE_GFX_VERSION=11.0.0
+          pytest -v tests/workflows
       - name: Lint for errors
         run: |
           pylint comfy
\ No newline at end of file
diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index a0ffea630..b2dff0d46 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -130,6 +130,12 @@ class EmbeddedComfyClient:
             from ..cmd.execution import PromptExecutor, validate_prompt
             prompt_mut = make_mutable(prompt)
             validation_tuple = validate_prompt(prompt_mut)
+            if not validation_tuple[0]:
+                span.set_status(Status(StatusCode.ERROR))
+                validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
+                error = ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
+                span.record_exception(error)
+                return {}
 
             prompt_executor: PromptExecutor = self._prompt_executor
 
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 4f7ea22b9..e4ae12377 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -24,7 +24,7 @@ _session = Session()
 
 def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
     existing = frozenset(folder_paths.get_filename_list(folder_name))
-    downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files if not isinstance(f, HuggingFile) or f.show_in_ui)
+    downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files)
     return sorted(list(existing | downloadable))
 
 
diff --git a/tests/workflows/__init__.py b/tests/workflows/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/workflows/lora.py b/tests/workflows/lora.py
new file mode 100644
index 000000000..027fcc239
--- /dev/null
+++ b/tests/workflows/lora.py
@@ -0,0 +1,154 @@
+import pytest
+import torch
+
+from comfy import model_management
+from comfy.api.components.schema.prompt import Prompt
+from comfy.model_downloader import add_known_models, KNOWN_LORAS
+from comfy.model_downloader_types import CivitFile
+from comfy.model_management import CPUState
+
+try:
+    has_gpu = torch.device(torch.cuda.current_device()) is not None
+except Exception:
+    has_gpu = False
+
+model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
+from comfy.client.embedded_comfy_client import EmbeddedComfyClient
+
+
+@pytest.mark.skipif(not has_gpu, reason="requires gpu for performant testing")
+@pytest.mark.asyncio
+async def test_lora_workflow():
+    prompt = Prompt.validate({
+        "3": {
+            "inputs": {
+                "seed": 851616030078638,
+                "steps": 20,
+                "cfg": 8,
+                "sampler_name": "euler",
+                "scheduler": "normal",
+                "denoise": 1,
+                "model": [
+                    "10",
+                    0
+                ],
+                "positive": [
+                    "6",
+                    0
+                ],
+                "negative": [
+                    "7",
+                    0
+                ],
+                "latent_image": [
+                    "5",
+                    0
+                ]
+            },
+            "class_type": "KSampler",
+            "_meta": {
+                "title": "KSampler"
+            }
+        },
+        "4": {
+            "inputs": {
+                "ckpt_name": "v1-5-pruned-emaonly.safetensors"
+            },
+            "class_type": "CheckpointLoaderSimple",
+            "_meta": {
+                "title": "Load Checkpoint"
+            }
+        },
+        "5": {
+            "inputs": {
+                "width": 512,
+                "height": 512,
+                "batch_size": 1
+            },
+            "class_type": "EmptyLatentImage",
+            "_meta": {
+                "title": "Empty Latent Image"
+            }
+        },
+        "6": {
+            "inputs": {
+                "text": "masterpiece best quality girl",
+                "clip": [
+                    "10",
+                    1
+                ]
+            },
+            "class_type": "CLIPTextEncode",
+            "_meta": {
+                "title": "CLIP Text Encode (Prompt)"
+            }
+        },
+        "7": {
+            "inputs": {
+                "text": "bad hands",
+                "clip": [
+                    "10",
+                    1
+                ]
+            },
+            "class_type": "CLIPTextEncode",
+            "_meta": {
+                "title": "CLIP Text Encode (Prompt)"
+            }
+        },
+        "8": {
+            "inputs": {
+                "samples": [
+                    "3",
+                    0
+                ],
+                "vae": [
+                    "4",
+                    2
+                ]
+            },
+            "class_type": "VAEDecode",
+            "_meta": {
+                "title": "VAE Decode"
+            }
+        },
+        "9": {
+            "inputs": {
+                "filename_prefix": "ComfyUI",
+                "images": [
+                    "8",
+                    0
+                ]
+            },
+            "class_type": "SaveImage",
+            "_meta": {
+                "title": "Save Image"
+            }
+        },
+        "10": {
+            "inputs": {
+                "lora_name": "epi_noiseoffset2.safetensors",
+                "strength_model": 1,
+                "strength_clip": 1,
+                "model": [
+                    "4",
+                    0
+                ],
+                "clip": [
+                    "4",
+                    1
+                ]
+            },
+            "class_type": "LoraLoader",
+            "_meta": {
+                "title": "Load LoRA"
+            }
+        }
+    })
+
+    add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
+    async with EmbeddedComfyClient() as client:
+        outputs = await client.queue_prompt(prompt)
+
+    save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
+    assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None