mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`, that variable will be a dictionary of the socket type of all incoming connections. If that argument exists, normal socket type validation will not occur. This removes the last hurdle for enabling variant types entirely from custom nodes, so I've removed that command-line option. I've added appropriate unit tests for these changes.
356 lines
15 KiB
Python
356 lines
15 KiB
Python
from io import BytesIO
|
|
import numpy
|
|
from PIL import Image
|
|
import pytest
|
|
from pytest import fixture
|
|
import time
|
|
import torch
|
|
from typing import Union, Dict
|
|
import json
|
|
import subprocess
|
|
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
|
import uuid
|
|
import urllib.request
|
|
import urllib.parse
|
|
import urllib.error
|
|
from comfy.graph_utils import GraphBuilder, Node
|
|
|
|
class RunResult:
|
|
def __init__(self, prompt_id: str):
|
|
self.outputs: Dict[str,Dict] = {}
|
|
self.runs: Dict[str,bool] = {}
|
|
self.prompt_id: str = prompt_id
|
|
|
|
def get_output(self, node: Node):
|
|
return self.outputs.get(node.id, None)
|
|
|
|
def did_run(self, node: Node):
|
|
return self.runs.get(node.id, False)
|
|
|
|
def get_images(self, node: Node):
|
|
output = self.get_output(node)
|
|
if output is None:
|
|
return []
|
|
return output.get('image_objects', [])
|
|
|
|
def get_prompt_id(self):
|
|
return self.prompt_id
|
|
|
|
class ComfyClient:
|
|
def __init__(self):
|
|
self.test_name = ""
|
|
|
|
def connect(self,
|
|
listen:str = '127.0.0.1',
|
|
port:Union[str,int] = 8188,
|
|
client_id: str = str(uuid.uuid4())
|
|
):
|
|
self.client_id = client_id
|
|
self.server_address = f"{listen}:{port}"
|
|
ws = websocket.WebSocket()
|
|
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
|
self.ws = ws
|
|
|
|
def queue_prompt(self, prompt):
|
|
p = {"prompt": prompt, "client_id": self.client_id}
|
|
data = json.dumps(p).encode('utf-8')
|
|
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
|
return json.loads(urllib.request.urlopen(req).read())
|
|
|
|
def get_image(self, filename, subfolder, folder_type):
|
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
|
url_values = urllib.parse.urlencode(data)
|
|
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
|
|
return response.read()
|
|
|
|
def get_history(self, prompt_id):
|
|
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
|
|
return json.loads(response.read())
|
|
|
|
def set_test_name(self, name):
|
|
self.test_name = name
|
|
|
|
def run(self, graph):
|
|
prompt = graph.finalize()
|
|
for node in graph.nodes.values():
|
|
if node.class_type == 'SaveImage':
|
|
node.inputs['filename_prefix'] = self.test_name
|
|
|
|
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
|
result = RunResult(prompt_id)
|
|
while True:
|
|
out = self.ws.recv()
|
|
if isinstance(out, str):
|
|
message = json.loads(out)
|
|
if message['type'] == 'executing':
|
|
data = message['data']
|
|
if data['prompt_id'] != prompt_id:
|
|
continue
|
|
if data['node'] is None:
|
|
break
|
|
result.runs[data['node']] = True
|
|
elif message['type'] == 'execution_error':
|
|
raise Exception(message['data'])
|
|
elif message['type'] == 'execution_cached':
|
|
pass # Probably want to store this off for testing
|
|
|
|
history = self.get_history(prompt_id)[prompt_id]
|
|
for o in history['outputs']:
|
|
for node_id in history['outputs']:
|
|
node_output = history['outputs'][node_id]
|
|
result.outputs[node_id] = node_output
|
|
if 'images' in node_output:
|
|
images_output = []
|
|
for image in node_output['images']:
|
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
|
image_obj = Image.open(BytesIO(image_data))
|
|
images_output.append(image_obj)
|
|
node_output['image_objects'] = images_output
|
|
|
|
return result
|
|
|
|
#
|
|
# Loop through these variables
|
|
#
|
|
@pytest.mark.execution
|
|
class TestExecution:
|
|
#
|
|
# Initialize server and client
|
|
#
|
|
@fixture(scope="class", autouse=True)
|
|
def _server(self, args_pytest):
|
|
# Start server
|
|
p = subprocess.Popen([
|
|
'python','main.py',
|
|
'--output-directory', args_pytest["output_dir"],
|
|
'--listen', args_pytest["listen"],
|
|
'--port', str(args_pytest["port"]),
|
|
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
|
])
|
|
yield
|
|
p.kill()
|
|
torch.cuda.empty_cache()
|
|
|
|
def start_client(self, listen:str, port:int):
|
|
# Start client
|
|
comfy_client = ComfyClient()
|
|
# Connect to server (with retries)
|
|
n_tries = 5
|
|
for i in range(n_tries):
|
|
time.sleep(4)
|
|
try:
|
|
comfy_client.connect(listen=listen, port=port)
|
|
except ConnectionRefusedError as e:
|
|
print(e)
|
|
print(f"({i+1}/{n_tries}) Retrying...")
|
|
else:
|
|
break
|
|
return comfy_client
|
|
|
|
@fixture(scope="class", autouse=True)
|
|
def shared_client(self, args_pytest, _server):
|
|
client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
|
yield client
|
|
del client
|
|
torch.cuda.empty_cache()
|
|
|
|
@fixture
|
|
def client(self, shared_client, request):
|
|
shared_client.set_test_name(f"execution[{request.node.name}]")
|
|
yield shared_client
|
|
|
|
def clear_cache(self, client: ComfyClient):
|
|
g = GraphBuilder(prefix="foo")
|
|
random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
|
|
g.node("PreviewImage", images=random.out(0))
|
|
client.run(g)
|
|
|
|
@fixture
|
|
def builder(self):
|
|
yield GraphBuilder(prefix="")
|
|
|
|
def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
|
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)
|
|
|
|
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
|
output = g.node("SaveImage", images=lazy_mix.out(0))
|
|
result = client.run(g)
|
|
|
|
result_image = result.get_images(output)[0]
|
|
assert numpy.array(result_image).any() == 0, "Image should be black"
|
|
assert result.did_run(input1)
|
|
assert not result.did_run(input2)
|
|
assert result.did_run(mask)
|
|
assert result.did_run(lazy_mix)
|
|
|
|
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
|
|
self.clear_cache(client)
|
|
g = builder
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
|
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
|
|
|
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
|
g.node("SaveImage", images=lazy_mix.out(0))
|
|
|
|
result1 = client.run(g)
|
|
result2 = client.run(g)
|
|
for node_id, node in g.nodes.items():
|
|
assert result1.did_run(node), f"Node {node_id} didn't run"
|
|
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
|
|
|
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
|
|
self.clear_cache(client)
|
|
g = builder
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
|
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
|
|
|
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
|
g.node("SaveImage", images=lazy_mix.out(0))
|
|
|
|
result1 = client.run(g)
|
|
mask.inputs['value'] = 0.4
|
|
result2 = client.run(g)
|
|
for node_id, node in g.nodes.items():
|
|
assert result1.did_run(node), f"Node {node_id} didn't run"
|
|
assert not result2.did_run(input1), "Input1 should have been cached"
|
|
assert not result2.did_run(input2), "Input2 should have been cached"
|
|
assert result2.did_run(mask), "Mask should have been re-run"
|
|
assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"
|
|
|
|
def test_error(self, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
# Different size of the two images
|
|
input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1)
|
|
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
|
|
|
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
|
g.node("SaveImage", images=lazy_mix.out(0))
|
|
|
|
try:
|
|
client.run(g)
|
|
except Exception as e:
|
|
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
|
|
|
|
@pytest.mark.parametrize("test_value, expect_error", [
|
|
(5, True),
|
|
("foo", True),
|
|
(5.0, False),
|
|
])
|
|
def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
|
|
g.node("SaveImage", images=validation1.out(0))
|
|
|
|
if expect_error:
|
|
with pytest.raises(urllib.error.HTTPError):
|
|
client.run(g)
|
|
else:
|
|
client.run(g)
|
|
|
|
@pytest.mark.parametrize("test_type, test_value", [
|
|
("StubInt", 5),
|
|
("StubFloat", 5.0)
|
|
])
|
|
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
stub = g.node(test_type, value=test_value)
|
|
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
|
|
g.node("SaveImage", images=validation1.out(0))
|
|
|
|
with pytest.raises(urllib.error.HTTPError):
|
|
client.run(g)
|
|
|
|
@pytest.mark.parametrize("test_type, test_value, expect_error", [
|
|
("StubInt", 5, True),
|
|
("StubFloat", 5.0, False)
|
|
])
|
|
def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
stub = g.node(test_type, value=test_value)
|
|
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
|
|
g.node("SaveImage", images=validation2.out(0))
|
|
|
|
if expect_error:
|
|
with pytest.raises(urllib.error.HTTPError):
|
|
client.run(g)
|
|
else:
|
|
client.run(g)
|
|
|
|
@pytest.mark.parametrize("test_type, test_value, expect_error", [
|
|
("StubInt", 5, True),
|
|
("StubFloat", 5.0, False)
|
|
])
|
|
def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
stub = g.node(test_type, value=test_value)
|
|
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
|
|
g.node("SaveImage", images=validation3.out(0))
|
|
|
|
if expect_error:
|
|
with pytest.raises(urllib.error.HTTPError):
|
|
client.run(g)
|
|
else:
|
|
client.run(g)
|
|
|
|
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
# Creating the nodes in this specific order previously caused a bug
|
|
save = g.node("SaveImage")
|
|
is_changed = g.node("TestCustomIsChanged", should_change=False)
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
|
|
save.set_input('images', is_changed.out(0))
|
|
is_changed.set_input('image', input1.out(0))
|
|
|
|
result1 = client.run(g)
|
|
result2 = client.run(g)
|
|
is_changed.set_input('should_change', True)
|
|
result3 = client.run(g)
|
|
result4 = client.run(g)
|
|
assert result1.did_run(is_changed), "is_changed should have been run"
|
|
assert not result2.did_run(is_changed), "is_changed should have been cached"
|
|
assert result3.did_run(is_changed), "is_changed should have been re-run"
|
|
assert result4.did_run(is_changed), "is_changed should not have been cached"
|
|
|
|
def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
|
self.clear_cache(client)
|
|
g = builder
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
|
input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0))
|
|
output = g.node("SaveImage", images=average.out(0))
|
|
|
|
result = client.run(g)
|
|
result_image = result.get_images(output)[0]
|
|
expected = 255 // 4
|
|
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
|
|
assert result.did_run(input1)
|
|
assert result.did_run(input2)
|
|
|
|
def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
|
|
g = builder
|
|
iterations = 4
|
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
|
is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0))
|
|
for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0))
|
|
average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2))
|
|
for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0))
|
|
output = g.node("SaveImage", images=for_close.out(0))
|
|
|
|
for iterations in range(1, 5):
|
|
for_open.set_input('remaining', iterations)
|
|
result = client.run(g)
|
|
result_image = result.get_images(output)[0]
|
|
expected = 255 // (2 ** iterations)
|
|
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
|
|
assert result.did_run(is_changed)
|