From 87d1f309026f27b982395676f9462c847a3a5367 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 16 May 2024 15:01:51 -0700
Subject: [PATCH] Some base nodes now have unit tests

---
 comfy/nodes/base_nodes.py        |   3 +-
 tests/unit/test_base_nodes.py    | 164 ++++++++++++++++++++++++++++++++
 tests/unit/test_openapi_nodes.py |   2 +-
 3 files changed, 167 insertions(+), 2 deletions(-)

diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 5bcae2868..ec521edef 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -31,6 +31,7 @@
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
 from ..open_exr import load_exr
 from .. import node_helpers
+from ..sd import VAE
 from ..utils import comfy_tqdm
 
@@ -280,7 +281,7 @@ class VAEEncode:
 
     CATEGORY = "latent"
 
-    def encode(self, vae, pixels):
+    def encode(self, vae: VAE, pixels):
         t = vae.encode(pixels[:,:,:,:3])
         return ({"samples":t}, )
 
diff --git a/tests/unit/test_base_nodes.py b/tests/unit/test_base_nodes.py
index e69de29bb..8322d4b5c 100644
--- a/tests/unit/test_base_nodes.py
+++ b/tests/unit/test_base_nodes.py
@@ -0,0 +1,164 @@
+import pytest
+import torch
+
+from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
+    LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
+    VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
+    CheckpointLoaderSimple, VAELoader, EmptyImage
+
+torch.set_grad_enabled(False)
+
+_image_1x1 = torch.ones((1, 1, 1, 3), dtype=torch.float32, device="cpu")
+_image_512x512 = torch.rand((1, 512, 512, 3), dtype=torch.float32, device="cpu")  # images are BHWC floats in [0, 1]
+
+_cond = torch.randn((1, 4, 77, 768))
+# conditioning is a list of [tensor, options] pairs
+_cond_with_pooled = [[_cond, {"pooled_output": torch.zeros((1, 1, 768))}]]
+
+_latent = {"samples": torch.randn((1, 4, 64, 64))}
+
+
+@pytest.fixture(scope="module")
+def vae():
+    vae, = VAELoader().load_vae("vae-ft-mse-840000-ema-pruned.safetensors")
+    return vae
+
+
+@pytest.fixture(scope="module")
+def checkpoint():
+    # load the checkpoint once per module; it supplies both the model and the CLIP
+    return CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")
+
+
+@pytest.fixture(scope="module")
+def clip(checkpoint):
+    return checkpoint[1]
+
+
+@pytest.fixture(scope="module")
+def model(checkpoint):
+    return checkpoint[0]
+
+
+def test_clip_text_encode(clip):
+    cond, = CLIPTextEncode().encode(clip, "test prompt")
+    assert len(cond) == 1
+    assert cond[0][0].shape == (1, 77, 768)
+    assert "pooled_output" in cond[0][1]
+    assert cond[0][1]["pooled_output"].shape == (1, 768)
+
+
+def test_conditioning_combine():
+    cond, = ConditioningCombine().combine(_cond_with_pooled, _cond_with_pooled)
+    assert len(cond) == 2
+    assert cond[0][0].shape == (1, 4, 77, 768)
+
+
+def test_conditioning_set_area(clip):
+    cond, = CLIPTextEncode().encode(clip, "test prompt")
+    cond, = ConditioningSetArea().append(cond, 64, 64, 0, 0, 1.0)
+    assert len(cond) == 1
+    assert cond[0][1]["area"] == (8, 8, 0, 0)
+    assert cond[0][1]["strength"] == 1.0
+
+
+def test_conditioning_set_mask(clip):
+    cond, = CLIPTextEncode().encode(clip, "test prompt")
+    mask = torch.ones((1, 64, 64))
+    cond, = ConditioningSetMask().append(cond, mask, "default", 1.0)
+    assert len(cond) == 1
+    assert torch.equal(cond[0][1]["mask"], mask)
+    assert cond[0][1]["mask_strength"] == 1.0
+
+
+def test_vae_decode(vae):
+    decoded, = VAEDecode().decode(vae, _latent)
+    assert decoded.shape == (1, 512, 512, 3)
+
+
+def test_vae_encode(vae):
+    latent, = VAEEncode().encode(vae, _image_512x512)
+    assert "samples" in latent
+    assert latent["samples"].shape == (1, 4, 64, 64)
+
+
+def test_vae_encode_for_inpaint(vae):
+    mask = torch.ones((1, 512, 512))
+    latent, = VAEEncodeForInpaint().encode(vae, _image_512x512, mask)
+    assert "samples" in latent
+    assert latent["samples"].shape == (1, 4, 64, 64)
+    assert "noise_mask" in latent
+    assert torch.allclose(latent["noise_mask"], mask)
+
+
+def test_inpaint_model_conditioning(model, vae, clip):
+    cond_pos, = CLIPTextEncode().encode(clip, "test prompt")
+    cond_neg, = CLIPTextEncode().encode(clip, "test negative prompt")
+    pos, neg, latent = InpaintModelConditioning().encode(cond_pos, cond_neg, _image_512x512, vae, torch.ones((1, 512, 512)))
+    assert len(pos) == len(cond_pos)
+    assert len(neg) == len(cond_neg)
+    assert "samples" in latent
+    assert "noise_mask" in latent
+
+
+def test_latent_upscale():
+    latent, = LatentUpscale().upscale(_latent, "nearest-exact", 1024, 1024, "disabled")
+    assert latent["samples"].shape == (1, 4, 128, 128)
+
+
+def test_latent_upscale_by():
+    latent, = LatentUpscaleBy().upscale(_latent, "nearest-exact", 2.0)
+    assert latent["samples"].shape == (1, 4, 128, 128)
+
+
+def test_latent_rotate():
+    latent, = LatentRotate().rotate(_latent, "90 degrees")
+    assert latent["samples"].shape == (1, 4, 64, 64)
+
+
+def test_latent_flip():
+    latent, = LatentFlip().flip(_latent, "y-axis: horizontally")
+    assert latent["samples"].shape == (1, 4, 64, 64)
+
+
+def test_latent_composite():
+    latent, = LatentComposite().composite(_latent, _latent, 0, 0)
+    assert latent["samples"].shape == (1, 4, 64, 64)
+
+
+def test_latent_crop():
+    latent, = LatentCrop().crop(_latent, 32, 32, 0, 0)
+    assert latent["samples"].shape == (1, 4, 4, 4)
+
+
+def test_image_scale():
+    image, = ImageScale().upscale(_image_1x1, "nearest-exact", 64, 64, "disabled")
+    assert image.shape == (1, 64, 64, 3)
+
+
+def test_image_scale_by():
+    image, = ImageScaleBy().upscale(_image_1x1, "nearest-exact", 2.0)
+    assert image.shape == (1, 2, 2, 3)
+
+
+def test_image_invert():
+    image, = ImageInvert().invert(_image_1x1)
+    assert image.shape == (1, 1, 1, 3)
+    assert torch.allclose(image, 1.0 - _image_1x1)
+
+
+def test_image_batch():
+    image, = ImageBatch().batch(_image_1x1, _image_1x1)
+    assert image.shape == (2, 1, 1, 3)
+
+
+def test_image_pad_for_outpaint():
+    padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
+    assert padded.shape == (1, 3, 3, 3)
+    assert mask.shape == (3, 3)
+
+
+def test_empty_image():
+    image, = EmptyImage().generate(64, 64, 1, 0xFF0000)
+    assert image.shape == (1, 64, 64, 3)
+    assert torch.allclose(image[0, 0, 0], torch.tensor([1.0, 0.0, 0.0]))
diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py
index 09bab96ef..8ef5665d4 100644
--- a/tests/unit/test_openapi_nodes.py
+++ b/tests/unit/test_openapi_nodes.py
@@ -226,7 +226,7 @@ def test_image_exif_merge():
 
 
 @freeze_time("2024-01-14 03:21:34", tz_offset=-4)
-@pytest.mark.skipif(sys.platform == 'win32')
+@pytest.mark.skipif(sys.platform == 'win32', reason="Windows does not have reliable time freezing")
 def test_image_exif_creation_date_and_batch_number():
     assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
     n = ImageExifCreationDateAndBatchNumber()
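
Note on running these tests: the module-scoped fixtures load real weights by
file name ("v1-5-pruned-emaonly.safetensors" through CheckpointLoaderSimple,
"vae-ft-mse-840000-ema-pruned.safetensors" through VAELoader), so every test
in the module errors during setup when those files are absent. A minimal
conftest.py sketch that skips the module instead is below; the models/ paths
are an assumption about the default checkpoint layout, not part of the patch:

import os

import pytest

# assumed default locations; adjust to the local model folder configuration
_REQUIRED_MODELS = [
    "models/checkpoints/v1-5-pruned-emaonly.safetensors",
    "models/vae/vae-ft-mse-840000-ema-pruned.safetensors",
]


def pytest_collection_modifyitems(config, items):
    # skip the base-node tests when the checkpoints their fixtures load are
    # missing, rather than erroring inside VAELoader/CheckpointLoaderSimple
    missing = [p for p in _REQUIRED_MODELS if not os.path.exists(p)]
    if not missing:
        return
    skip = pytest.mark.skip(reason="missing model files: " + ", ".join(missing))
    for item in items:
        if "test_base_nodes" in item.nodeid:
            item.add_marker(skip)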