Some base nodes now have unit tests

This commit is contained in:
doctorpangloss 2024-05-16 15:01:51 -07:00
parent 3d98440fb7
commit 87d1f30902
3 changed files with 162 additions and 2 deletions

View File

@ -31,6 +31,7 @@ from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
from .. import node_helpers
from ..sd import VAE
from ..utils import comfy_tqdm
@ -280,7 +281,7 @@ class VAEEncode:
CATEGORY = "latent"
def encode(self, vae, pixels):
def encode(self, vae: VAE, pixels):
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )

View File

@ -0,0 +1,159 @@
import numpy as np
import pytest
import torch
from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
CheckpointLoaderSimple, VAELoader, EmptyImage
torch.set_grad_enabled(False)
_image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8)
_image_1x1 = torch.ones((1, 1, 1, 3), dtype=torch.float32, device="cpu")
_image_512x512 = torch.randn((1,512,512,3) , dtype=torch.float32, device="cpu")
_cond = torch.randn((1, 4, 77, 768))
_cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
_latent = {"samples": torch.randn((1, 4, 64, 64))}
@pytest.fixture(scope="module")
def vae():
vae, = VAELoader().load_vae("vae-ft-mse-840000-ema-pruned.safetensors")
return vae
@pytest.fixture(scope="module")
def clip(vae):
return CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")[1]
@pytest.fixture(scope="module")
def model(clip, vae):
return CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")[0]
def test_clip_text_encode(clip):
cond, = CLIPTextEncode().encode(clip, "test prompt")
assert len(cond) == 1
assert cond[0][0].shape == (1, 77, 768)
assert "pooled_output" in cond[0][1]
assert cond[0][1]["pooled_output"].shape == (1, 768)
def test_conditioning_combine():
cond = ConditioningCombine().combine(_cond_with_pooled, _cond_with_pooled)
assert len(cond) == 1
assert cond[0][0].shape == (1, 4, 77, 768)
def test_conditioning_set_area(clip):
cond, = CLIPTextEncode().encode(clip, "test prompt")
cond, = ConditioningSetArea().append(cond, 64, 64, 0, 0, 1.0)
assert len(cond) == 1
assert cond[0][1]["area"] == (8, 8, 0, 0)
assert cond[0][1]["strength"] == 1.0
def test_conditioning_set_mask(clip):
cond, = CLIPTextEncode().encode(clip, "test prompt")
mask = torch.ones((1, 64, 64))
cond, = ConditioningSetMask().append(cond, mask, "default", 1.0)
assert len(cond) == 1
assert torch.equal(cond[0][1]["mask"], mask)
assert cond[0][1]["mask_strength"] == 1.0
def test_vae_decode(vae):
decoded, = VAEDecode().decode(vae, _latent)
assert decoded.shape == (1, 512, 512, 3)
def test_vae_encode(vae):
latent, = VAEEncode().encode(vae, _image_512x512)
assert "samples" in latent
assert latent["samples"].shape == (1, 4, 64, 64)
def test_vae_encode_for_inpaint(vae):
mask = torch.ones((1, 512, 512))
latent, = VAEEncodeForInpaint().encode(vae, _image_512x512, mask)
assert "samples" in latent
assert latent["samples"].shape == (1, 4, 64, 64)
assert "noise_mask" in latent
assert torch.allclose(latent["noise_mask"], mask)
def test_inpaint_model_conditioning(model, vae, clip):
cond_pos, = CLIPTextEncode().encode(clip, "test prompt")
cond_neg, = CLIPTextEncode().encode(clip, "test negative prompt")
pos, neg, latent = InpaintModelConditioning().encode(cond_pos, cond_neg, _image_512x512, vae, torch.ones((1, 512, 512)))
assert len(pos) == len(cond_pos)
assert len(neg) == len(cond_neg)
assert "samples" in latent
assert "noise_mask" in latent
def test_latent_upscale():
latent, = LatentUpscale().upscale(_latent, "nearest-exact", 1024, 1024, "disabled")
assert latent["samples"].shape == (1, 4, 128, 128)
def test_latent_upscale_by():
latent, = LatentUpscaleBy().upscale(_latent, "nearest-exact", 2.0)
assert latent["samples"].shape == (1, 4, 128, 128)
def test_latent_rotate():
latent, = LatentRotate().rotate(_latent, "90 degrees")
assert latent["samples"].shape == (1, 4, 64, 64)
def test_latent_flip():
latent, = LatentFlip().flip(_latent, "y-axis: horizontally")
assert latent["samples"].shape == (1, 4, 64, 64)
def test_latent_composite():
latent, = LatentComposite().composite(_latent, _latent, 0, 0)
assert latent["samples"].shape == (1, 4, 64, 64)
def test_latent_crop():
latent, = LatentCrop().crop(_latent, 32, 32, 0, 0)
assert latent["samples"].shape == (1, 4, 4, 4)
def test_image_scale():
image, = ImageScale().upscale(_image_1x1, "nearest-exact", 64, 64, "disabled")
assert image.shape == (1, 64, 64, 3)
def test_image_scale_by():
image, = ImageScaleBy().upscale(_image_1x1, "nearest-exact", 2.0)
assert image.shape == (1, 2, 2, 3)
def test_image_invert():
image, = ImageInvert().invert(_image_1x1)
assert image.shape == (1, 1, 1, 3)
assert torch.allclose(image, 1.0 - _image_1x1)
def test_image_batch():
image, = ImageBatch().batch(_image_1x1, _image_1x1)
assert image.shape == (2, 1, 1, 3)
def test_image_pad_for_outpaint():
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
assert padded.shape == (1, 3, 3, 3)
assert mask.shape == (3, 3)
def test_empty_image():
image, = EmptyImage().generate(64, 64, 1, 0xFF0000)
assert image.shape == (1, 64, 64, 3)
assert torch.allclose(image[0, 0, 0], torch.tensor([1.0, 0.0, 0.0]))

View File

@ -226,7 +226,7 @@ def test_image_exif_merge():
@freeze_time("2024-01-14 03:21:34", tz_offset=-4)
@pytest.mark.skipif(sys.platform == 'win32')
@pytest.mark.skipif(sys.platform == 'win32', reason="Windows does not have reliable time freezing")
def test_image_exif_creation_date_and_batch_number():
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
n = ImageExifCreationDateAndBatchNumber()