Separate server fixture

This commit is contained in:
enzymezoo-code 2023-09-04 21:24:29 -05:00
parent 5b4aaa6829
commit d7533bfcd1

View File

@ -127,7 +127,6 @@ DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
prompt_list = [
'a painting of a cat',
'a photo of a toad',
]
sampler_list = KSampler.SAMPLERS[0:2]
scheduler_list = [KSampler.SCHEDULERS[0]]
@ -137,13 +136,9 @@ scheduler_list = [KSampler.SCHEDULERS[0]]
@pytest.mark.parametrize("scheduler", scheduler_list)
@pytest.mark.parametrize("prompt", prompt_list)
class TestInference:
# Initialize pipeline
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
# The "graph" is the default graph
@fixture(scope="class", params=comfy_graph_list, autouse=True)
def _client_graph(self, request, args_pytest) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
# Initialize server
@fixture(scope="class", autouse=True)
def _server(self, args_pytest):
# Start server
p = subprocess.Popen([
'python','main.py',
@ -151,7 +146,11 @@ class TestInference:
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
])
yield
p.kill()
torch.cuda.empty_cache()
def start_client(self, listen:str, port:int):
# Start client
comfy_client = ComfyClient()
# Connect to server (with retries)
@ -159,20 +158,29 @@ class TestInference:
for i in range(n_tries):
time.sleep(4)
try:
comfy_client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
comfy_client.connect(listen=listen, port=port)
except ConnectionRefusedError as e:
print(e)
print(f"({i+1}/{n_tries}) Retrying...")
else:
break
return comfy_client
# warm up pipeline
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
# The "graph" is the default graph
@fixture(scope="class", params=comfy_graph_list, autouse=True)
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
comfy_graph = request.param
# Start client
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
# Warm up pipeline
comfy_client.get_images(graph=comfy_graph.graph, save=False)
yield comfy_client, comfy_graph
del comfy_client
del comfy_graph
p.kill()
torch.cuda.empty_cache()
@fixture
@ -180,7 +188,7 @@ class TestInference:
client = _client_graph[0]
yield client
# method-scoped fixture for graph to avoid mutating the graph
# function-scoped fixture for graph to avoid mutating the graph
@fixture
def comfy_graph(self, _client_graph):
graph = deepcopy(_client_graph[1])