mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 08:40:50 +08:00
Some base nodes now have unit tests
This commit is contained in:
parent
3d98440fb7
commit
87d1f30902
@ -31,6 +31,7 @@ from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
from ..open_exr import load_exr
|
||||
from .. import node_helpers
|
||||
from ..sd import VAE
|
||||
from ..utils import comfy_tqdm
|
||||
|
||||
|
||||
@ -280,7 +281,7 @@ class VAEEncode:
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
def encode(self, vae: VAE, pixels):
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
return ({"samples":t}, )
|
||||
|
||||
|
||||
@ -0,0 +1,159 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
|
||||
LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
|
||||
VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
|
||||
CheckpointLoaderSimple, VAELoader, EmptyImage
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
_image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8)
|
||||
_image_1x1 = torch.ones((1, 1, 1, 3), dtype=torch.float32, device="cpu")
|
||||
_image_512x512 = torch.randn((1,512,512,3) , dtype=torch.float32, device="cpu")
|
||||
|
||||
_cond = torch.randn((1, 4, 77, 768))
|
||||
_cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
|
||||
|
||||
_latent = {"samples": torch.randn((1, 4, 64, 64))}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vae():
|
||||
vae, = VAELoader().load_vae("vae-ft-mse-840000-ema-pruned.safetensors")
|
||||
return vae
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def clip(vae):
|
||||
return CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")[1]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model(clip, vae):
|
||||
return CheckpointLoaderSimple().load_checkpoint("v1-5-pruned-emaonly.safetensors")[0]
|
||||
|
||||
|
||||
def test_clip_text_encode(clip):
|
||||
cond, = CLIPTextEncode().encode(clip, "test prompt")
|
||||
assert len(cond) == 1
|
||||
assert cond[0][0].shape == (1, 77, 768)
|
||||
assert "pooled_output" in cond[0][1]
|
||||
assert cond[0][1]["pooled_output"].shape == (1, 768)
|
||||
|
||||
|
||||
def test_conditioning_combine():
|
||||
cond = ConditioningCombine().combine(_cond_with_pooled, _cond_with_pooled)
|
||||
assert len(cond) == 1
|
||||
assert cond[0][0].shape == (1, 4, 77, 768)
|
||||
|
||||
|
||||
def test_conditioning_set_area(clip):
|
||||
cond, = CLIPTextEncode().encode(clip, "test prompt")
|
||||
cond, = ConditioningSetArea().append(cond, 64, 64, 0, 0, 1.0)
|
||||
assert len(cond) == 1
|
||||
assert cond[0][1]["area"] == (8, 8, 0, 0)
|
||||
assert cond[0][1]["strength"] == 1.0
|
||||
|
||||
|
||||
def test_conditioning_set_mask(clip):
|
||||
cond, = CLIPTextEncode().encode(clip, "test prompt")
|
||||
mask = torch.ones((1, 64, 64))
|
||||
cond, = ConditioningSetMask().append(cond, mask, "default", 1.0)
|
||||
assert len(cond) == 1
|
||||
assert torch.equal(cond[0][1]["mask"], mask)
|
||||
assert cond[0][1]["mask_strength"] == 1.0
|
||||
|
||||
|
||||
def test_vae_decode(vae):
|
||||
decoded, = VAEDecode().decode(vae, _latent)
|
||||
assert decoded.shape == (1, 512, 512, 3)
|
||||
|
||||
|
||||
def test_vae_encode(vae):
|
||||
latent, = VAEEncode().encode(vae, _image_512x512)
|
||||
assert "samples" in latent
|
||||
assert latent["samples"].shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
def test_vae_encode_for_inpaint(vae):
|
||||
mask = torch.ones((1, 512, 512))
|
||||
latent, = VAEEncodeForInpaint().encode(vae, _image_512x512, mask)
|
||||
assert "samples" in latent
|
||||
assert latent["samples"].shape == (1, 4, 64, 64)
|
||||
assert "noise_mask" in latent
|
||||
assert torch.allclose(latent["noise_mask"], mask)
|
||||
|
||||
|
||||
def test_inpaint_model_conditioning(model, vae, clip):
|
||||
cond_pos, = CLIPTextEncode().encode(clip, "test prompt")
|
||||
cond_neg, = CLIPTextEncode().encode(clip, "test negative prompt")
|
||||
pos, neg, latent = InpaintModelConditioning().encode(cond_pos, cond_neg, _image_512x512, vae, torch.ones((1, 512, 512)))
|
||||
assert len(pos) == len(cond_pos)
|
||||
assert len(neg) == len(cond_neg)
|
||||
assert "samples" in latent
|
||||
assert "noise_mask" in latent
|
||||
|
||||
|
||||
def test_latent_upscale():
|
||||
latent, = LatentUpscale().upscale(_latent, "nearest-exact", 1024, 1024, "disabled")
|
||||
assert latent["samples"].shape == (1, 4, 128, 128)
|
||||
|
||||
|
||||
def test_latent_upscale_by():
|
||||
latent, = LatentUpscaleBy().upscale(_latent, "nearest-exact", 2.0)
|
||||
assert latent["samples"].shape == (1, 4, 128, 128)
|
||||
|
||||
|
||||
def test_latent_rotate():
|
||||
latent, = LatentRotate().rotate(_latent, "90 degrees")
|
||||
assert latent["samples"].shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
def test_latent_flip():
|
||||
latent, = LatentFlip().flip(_latent, "y-axis: horizontally")
|
||||
assert latent["samples"].shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
def test_latent_composite():
|
||||
latent, = LatentComposite().composite(_latent, _latent, 0, 0)
|
||||
assert latent["samples"].shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
def test_latent_crop():
|
||||
latent, = LatentCrop().crop(_latent, 32, 32, 0, 0)
|
||||
assert latent["samples"].shape == (1, 4, 4, 4)
|
||||
|
||||
|
||||
def test_image_scale():
|
||||
image, = ImageScale().upscale(_image_1x1, "nearest-exact", 64, 64, "disabled")
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
def test_image_scale_by():
|
||||
image, = ImageScaleBy().upscale(_image_1x1, "nearest-exact", 2.0)
|
||||
assert image.shape == (1, 2, 2, 3)
|
||||
|
||||
|
||||
def test_image_invert():
|
||||
image, = ImageInvert().invert(_image_1x1)
|
||||
assert image.shape == (1, 1, 1, 3)
|
||||
assert torch.allclose(image, 1.0 - _image_1x1)
|
||||
|
||||
|
||||
def test_image_batch():
|
||||
image, = ImageBatch().batch(_image_1x1, _image_1x1)
|
||||
assert image.shape == (2, 1, 1, 3)
|
||||
|
||||
|
||||
def test_image_pad_for_outpaint():
|
||||
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
|
||||
assert padded.shape == (1, 3, 3, 3)
|
||||
assert mask.shape == (3, 3)
|
||||
|
||||
|
||||
def test_empty_image():
|
||||
image, = EmptyImage().generate(64, 64, 1, 0xFF0000)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
assert torch.allclose(image[0, 0, 0], torch.tensor([1.0, 0.0, 0.0]))
|
||||
@ -226,7 +226,7 @@ def test_image_exif_merge():
|
||||
|
||||
|
||||
@freeze_time("2024-01-14 03:21:34", tz_offset=-4)
|
||||
@pytest.mark.skipif(sys.platform == 'win32')
|
||||
@pytest.mark.skipif(sys.platform == 'win32', reason="Windows does not have reliable time freezing")
|
||||
def test_image_exif_creation_date_and_batch_number():
|
||||
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
|
||||
n = ImageExifCreationDateAndBatchNumber()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user