ComfyUI/tests/inference/test_execution.py
Jacob Segal 6d09dd70f8 Make custom VALIDATE_INPUTS skip normal validation
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that argument will receive a dictionary of the socket types of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle to implementing variant types
entirely within custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.
2024-02-24 23:17:01 -08:00

356 lines
15 KiB
Python

from io import BytesIO
import numpy
from PIL import Image
import pytest
from pytest import fixture
import time
import torch
from typing import Union, Dict
import json
import subprocess
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import urllib.request
import urllib.parse
import urllib.error
from comfy.graph_utils import GraphBuilder, Node
class RunResult:
    """Outcome of one prompt execution as collected by ``ComfyClient.run``.

    ``outputs`` maps a node id to that node's entry in the server history
    (``ComfyClient.run`` may add decoded images under ``image_objects``);
    ``runs`` marks the ids of nodes the server reported as executing.
    """

    def __init__(self, prompt_id: str):
        self.prompt_id: str = prompt_id
        self.runs: Dict[str, bool] = {}
        self.outputs: Dict[str, Dict] = {}

    def get_output(self, node: Node):
        """Return the recorded output dict for *node*, or None if there is none."""
        node_id = node.id
        return self.outputs[node_id] if node_id in self.outputs else None

    def did_run(self, node: Node):
        """Return True when *node* was reported as executed, otherwise False."""
        return node.id in self.runs and self.runs[node.id]

    def get_images(self, node: Node):
        """Return the decoded images attached to *node*'s output (may be empty)."""
        output = self.get_output(node)
        if output is not None:
            return output.get('image_objects', [])
        return []

    def get_prompt_id(self):
        """Return the prompt id this result belongs to."""
        return self.prompt_id
class ComfyClient:
    """Small HTTP + websocket client used to drive a running ComfyUI server in tests."""

    def __init__(self):
        # Written into every SaveImage node's ``filename_prefix`` by ``run``.
        self.test_name = ""

    def connect(self,
                listen: str = '127.0.0.1',
                port: Union[str, int] = 8188,
                client_id: Union[str, None] = None
                ):
        """Open the websocket on which execution events are received.

        Args:
            listen: Host the server listens on.
            port: Port the server listens on.
            client_id: Id sent along with every queued prompt; a fresh UUID is
                generated when omitted.
        """
        # Generate the id per call: a ``str(uuid.uuid4())`` default in the
        # signature is evaluated once and would be shared by every client.
        if client_id is None:
            client_id = str(uuid.uuid4())
        self.client_id = client_id
        self.server_address = f"{listen}:{port}"
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
        self.ws = ws

    def queue_prompt(self, prompt):
        """POST *prompt* to ``/prompt`` and return the decoded JSON reply.

        Raises:
            urllib.error.HTTPError: if the server rejects the prompt (the tests
                rely on this for validation failures).
        """
        p = {"prompt": prompt, "client_id": self.client_id}
        data = json.dumps(p).encode('utf-8')
        req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
        # Use the response as a context manager so the connection is closed.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())

    def get_image(self, filename, subfolder, folder_type):
        """Download the raw bytes of an output file from the ``/view`` endpoint."""
        data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
        url_values = urllib.parse.urlencode(data)
        with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
            return response.read()

    def get_history(self, prompt_id):
        """Fetch and decode the ``/history/<prompt_id>`` JSON document."""
        with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
            return json.loads(response.read())

    def set_test_name(self, name):
        """Set the name used as ``filename_prefix`` for SaveImage nodes in later runs."""
        self.test_name = name

    def _apply_test_name(self, graph):
        """Set ``filename_prefix`` on every SaveImage node of *graph* to the test name."""
        for node in graph.nodes.values():
            if node.class_type == 'SaveImage':
                node.inputs['filename_prefix'] = self.test_name

    def run(self, graph):
        """Queue *graph*, wait for the server to finish it, and collect its outputs.

        Returns:
            A ``RunResult`` recording which nodes executed and each node's
            history output; image outputs are downloaded and decoded into
            ``image_objects``.

        Raises:
            Exception: wrapping the server's ``execution_error`` payload.
            urllib.error.HTTPError: if the server rejects the prompt.
        """
        # Apply the prefix before finalizing so it is part of the submitted
        # prompt whether or not finalize() copies the node inputs.
        self._apply_test_name(graph)
        prompt = graph.finalize()
        prompt_id = self.queue_prompt(prompt)['prompt_id']
        result = RunResult(prompt_id)
        while True:
            out = self.ws.recv()
            if not isinstance(out, str):
                # Only text (JSON) frames are inspected; binary frames are skipped.
                continue
            message = json.loads(out)
            if message['type'] == 'executing':
                data = message['data']
                if data['prompt_id'] != prompt_id:
                    continue
                if data['node'] is None:
                    # A null node for our prompt signals that execution finished.
                    break
                result.runs[data['node']] = True
            elif message['type'] == 'execution_error':
                raise Exception(message['data'])
            elif message['type'] == 'execution_cached':
                pass  # Probably want to store this off for testing

        history = self.get_history(prompt_id)[prompt_id]
        # A single pass: each node's images are fetched and decoded exactly once.
        for node_id, node_output in history['outputs'].items():
            result.outputs[node_id] = node_output
            if 'images' in node_output:
                node_output['image_objects'] = [
                    Image.open(BytesIO(self.get_image(image['filename'], image['subfolder'], image['type'])))
                    for image in node_output['images']
                ]
        return result
#
# End-to-end execution tests, run against a live server process
#
@pytest.mark.execution
class TestExecution:
    """End-to-end execution tests against a ComfyUI server started for this class."""

    #
    # Initialize server and client
    #
    @fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Run ``main.py`` as a subprocess for the lifetime of the test class."""
        # Start server
        p = subprocess.Popen([
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
        ])
        yield
        p.kill()
        # Reap the killed process so it does not linger as a zombie.
        p.wait()
        torch.cuda.empty_cache()

    def start_client(self, listen: str, port: int):
        """Create a ComfyClient and connect it, retrying while the server starts.

        Raises:
            ConnectionRefusedError: if the server is still unreachable after
                all attempts.
        """
        # Start client
        comfy_client = ComfyClient()
        # Connect to server (with retries)
        n_tries = 5
        for i in range(n_tries):
            # Give the freshly spawned server time to come up before each attempt.
            time.sleep(4)
            try:
                comfy_client.connect(listen=listen, port=port)
            except ConnectionRefusedError as e:
                print(e)
                print(f"({i+1}/{n_tries}) Retrying...")
            else:
                break
        else:
            # Fail here rather than hand back a client with no websocket, which
            # would only surface later as an unrelated AttributeError.
            raise ConnectionRefusedError(
                f"Could not connect to server at {listen}:{port} after {n_tries} attempts")
        return comfy_client

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Class-wide client connected to the server launched by ``_server``."""
        client = self.start_client(args_pytest["listen"], args_pytest["port"])
        yield client
        # Deleting the local name does not close the socket; do it explicitly.
        client.ws.close()
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Per-test view of the shared client, tagged with the current test name."""
        shared_client.set_test_name(f"execution[{request.node.name}]")
        yield shared_client

    def clear_cache(self, client: ComfyClient):
        """Run an unrelated 1x1 noise graph so later runs start from a fresh cache."""
        g = GraphBuilder(prefix="foo")
        noise = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
        g.node("PreviewImage", images=noise.out(0))
        client.run(g)

    @fixture
    def builder(self):
        """Fresh GraphBuilder for each test."""
        yield GraphBuilder(prefix="")

    def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        output = g.node("SaveImage", images=lazy_mix.out(0))
        result = client.run(g)

        result_image = result.get_images(output)[0]
        assert not numpy.array(result_image).any(), "Image should be black"
        assert result.did_run(input1)
        # With a zero mask the second image is never needed, so it must not run.
        assert not result.did_run(input2)
        assert result.did_run(mask)
        assert result.did_run(lazy_mix)

    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
        self.clear_cache(client)
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        result1 = client.run(g)
        result2 = client.run(g)
        for node_id, node in g.nodes.items():
            assert result1.did_run(node), f"Node {node_id} didn't run"
            assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"

    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
        self.clear_cache(client)
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        result1 = client.run(g)
        mask.set_input('value', 0.4)
        result2 = client.run(g)
        for node_id, node in g.nodes.items():
            assert result1.did_run(node), f"Node {node_id} didn't run"
        assert not result2.did_run(input1), "Input1 should have been cached"
        assert not result2.did_run(input2), "Input2 should have been cached"
        assert result2.did_run(mask), "Mask should have been re-run"
        assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"

    def test_error(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Different size of the two images
        input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        # pytest.raises fails the test if no error is raised at all.
        with pytest.raises(Exception) as exc_info:
            client.run(g)
        e = exc_info.value
        assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"

    @pytest.mark.parametrize("test_value, expect_error", [
        (5, True),
        ("foo", True),
        (5.0, False),
    ])
    def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value", [
        ("StubInt", 5),
        ("StubFloat", 5.0)
    ])
    def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation2.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation3.out(0))

        if expect_error:
            with pytest.raises(urllib.error.HTTPError):
                client.run(g)
        else:
            client.run(g)

    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        # Creating the nodes in this specific order previously caused a bug
        save = g.node("SaveImage")
        is_changed = g.node("TestCustomIsChanged", should_change=False)
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        save.set_input('images', is_changed.out(0))
        is_changed.set_input('image', input1.out(0))

        result1 = client.run(g)
        result2 = client.run(g)
        is_changed.set_input('should_change', True)
        result3 = client.run(g)
        result4 = client.run(g)
        assert result1.did_run(is_changed), "is_changed should have been run"
        assert not result2.did_run(is_changed), "is_changed should have been cached"
        assert result3.did_run(is_changed), "is_changed should have been re-run"
        assert result4.did_run(is_changed), "is_changed should not have been cached"

    def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
        self.clear_cache(client)
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)
        result_image = result.get_images(output)[0]
        # One white image averaged with three black ones.
        expected = 255 // 4
        assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
        assert result.did_run(input1)
        assert result.did_run(input2)

    def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
        g = builder
        iterations = 4
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0))
        for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0))
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2))
        for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0))
        output = g.node("SaveImage", images=for_close.out(0))

        # Presumably each loop pass averages the carried value with the black
        # image, halving the white level once per iteration.
        for remaining in range(1, 5):
            for_open.set_input('remaining', remaining)
            result = client.run(g)
            result_image = result.get_images(output)[0]
            expected = 255 // (2 ** remaining)
            assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
            assert result.did_run(is_changed)