LoRA test

This commit is contained in:
doctorpangloss 2024-06-28 17:06:29 -07:00
parent 531a2c879d
commit 2cad5ec0d6
5 changed files with 165 additions and 1 deletion

View File

@ -41,6 +41,10 @@ jobs:
run: | run: |
export HSA_OVERRIDE_GFX_VERSION=11.0.0 export HSA_OVERRIDE_GFX_VERSION=11.0.0
pytest -v tests/unit pytest -v tests/unit
- name: Run lora workflow
run: |
export HSA_OVERRIDE_GFX_VERSION=11.0.0
pytest -v tests/workflows
- name: Lint for errors - name: Lint for errors
run: | run: |
pylint comfy pylint comfy

View File

@ -130,6 +130,12 @@ class EmbeddedComfyClient:
from ..cmd.execution import PromptExecutor, validate_prompt from ..cmd.execution import PromptExecutor, validate_prompt
prompt_mut = make_mutable(prompt) prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut) validation_tuple = validate_prompt(prompt_mut)
if not validation_tuple[0]:
span.set_status(Status(StatusCode.ERROR))
validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
error = ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
span.record_exception(error)
return {}
prompt_executor: PromptExecutor = self._prompt_executor prompt_executor: PromptExecutor = self._prompt_executor

View File

@ -24,7 +24,7 @@ _session = Session()
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]: def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
existing = frozenset(folder_paths.get_filename_list(folder_name)) existing = frozenset(folder_paths.get_filename_list(folder_name))
downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files if not isinstance(f, HuggingFile) or f.show_in_ui) downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files)
return sorted(list(existing | downloadable)) return sorted(list(existing | downloadable))

View File

154
tests/workflows/lora.py Normal file
View File

@ -0,0 +1,154 @@
import pytest
import torch
from comfy import model_management
from comfy.api.components.schema.prompt import Prompt
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
from comfy.model_management import CPUState
# Probe for a usable CUDA device. torch.cuda.current_device() raises when no
# CUDA runtime/driver/device is present, so any failure means "no GPU".
try:
    has_gpu = torch.device(torch.cuda.current_device()) is not None
except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt still propagate
    has_gpu = False
# Force comfy's device selection to match the probe result; presumably this must
# be set before the EmbeddedComfyClient import below reads it — TODO confirm.
model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
@pytest.mark.skipif(not has_gpu, reason="requires gpu for performant testing")
@pytest.mark.asyncio
async def test_lora_workflow():
    """Run a text-to-image workflow that applies a LoRA and assert an image was saved.

    Node graph: CheckpointLoaderSimple ("4") feeds LoraLoader ("10"), whose
    patched model/clip feed two CLIPTextEncode nodes ("6" positive, "7"
    negative) and KSampler ("3"); the sampled latent goes through VAEDecode
    ("8") into SaveImage ("9").
    """
    # Validate the raw API-format workflow dict into a typed Prompt object.
    prompt = Prompt.validate({
        "3": {
            "inputs": {
                "seed": 851616030078638,
                "steps": 20,
                "cfg": 8,
                "sampler_name": "euler",
                "scheduler": "normal",
                "denoise": 1,
                "model": [
                    "10",
                    0
                ],
                "positive": [
                    "6",
                    0
                ],
                "negative": [
                    "7",
                    0
                ],
                "latent_image": [
                    "5",
                    0
                ]
            },
            "class_type": "KSampler",
            "_meta": {
                "title": "KSampler"
            }
        },
        "4": {
            "inputs": {
                "ckpt_name": "v1-5-pruned-emaonly.safetensors"
            },
            "class_type": "CheckpointLoaderSimple",
            "_meta": {
                "title": "Load Checkpoint"
            }
        },
        "5": {
            "inputs": {
                "width": 512,
                "height": 512,
                "batch_size": 1
            },
            "class_type": "EmptyLatentImage",
            "_meta": {
                "title": "Empty Latent Image"
            }
        },
        "6": {
            "inputs": {
                "text": "masterpiece best quality girl",
                "clip": [
                    "10",
                    1
                ]
            },
            "class_type": "CLIPTextEncode",
            "_meta": {
                "title": "CLIP Text Encode (Prompt)"
            }
        },
        "7": {
            "inputs": {
                "text": "bad hands",
                "clip": [
                    "10",
                    1
                ]
            },
            "class_type": "CLIPTextEncode",
            "_meta": {
                "title": "CLIP Text Encode (Prompt)"
            }
        },
        "8": {
            "inputs": {
                "samples": [
                    "3",
                    0
                ],
                "vae": [
                    "4",
                    2
                ]
            },
            "class_type": "VAEDecode",
            "_meta": {
                "title": "VAE Decode"
            }
        },
        "9": {
            "inputs": {
                "filename_prefix": "ComfyUI",
                "images": [
                    "8",
                    0
                ]
            },
            "class_type": "SaveImage",
            "_meta": {
                "title": "Save Image"
            }
        },
        "10": {
            "inputs": {
                "lora_name": "epi_noiseoffset2.safetensors",
                "strength_model": 1,
                "strength_clip": 1,
                "model": [
                    "4",
                    0
                ],
                "clip": [
                    "4",
                    1
                ]
            },
            "class_type": "LoraLoader",
            "_meta": {
                "title": "Load LoRA"
            }
        }
    })
    # Register the LoRA as a known CivitAI file so LoraLoader can resolve it by
    # filename; presumably this enables an on-demand download when the file is
    # not already present locally — verify against comfy.model_downloader.
    add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
    # Execute the workflow in-process; the context manager owns executor setup/teardown.
    async with EmbeddedComfyClient() as client:
        outputs = await client.queue_prompt(prompt)
    # Locate the SaveImage node by class_type (its id is "9" here, but look it
    # up so the assertion survives renumbering) and check an image was written.
    save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
    assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None