From d7533bfcd1770e6ea5ed04a6963b948d8cc68ad6 Mon Sep 17 00:00:00 2001 From: enzymezoo-code Date: Mon, 4 Sep 2023 21:24:29 -0500 Subject: [PATCH] Separate server fixture --- tests/inference/test_inference.py | 34 +++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 52ed60b0e..4a6c19296 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -127,7 +127,6 @@ DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) comfy_graph_list = [DEFAULT_COMFY_GRAPH] prompt_list = [ 'a painting of a cat', - 'a photo of a toad', ] sampler_list = KSampler.SAMPLERS[0:2] scheduler_list = [KSampler.SCHEDULERS[0]] @@ -137,13 +136,9 @@ scheduler_list = [KSampler.SCHEDULERS[0]] @pytest.mark.parametrize("scheduler", scheduler_list) @pytest.mark.parametrize("prompt", prompt_list) class TestInference: - # Initialize pipeline - # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server - # The "graph" is the default graph - @fixture(scope="class", params=comfy_graph_list, autouse=True) - def _client_graph(self, request, args_pytest) -> (ComfyClient, ComfyGraph): - comfy_graph = request.param - + # Initialize server + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): # Start server p = subprocess.Popen([ 'python','main.py', @@ -151,7 +146,11 @@ class TestInference: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), ]) - + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): # Start client comfy_client = ComfyClient() # Connect to server (with retries) @@ -159,20 +158,29 @@ class TestInference: for i in range(n_tries): time.sleep(4) try: - comfy_client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + comfy_client.connect(listen=listen, port=port) except ConnectionRefusedError as e: print(e) print(f"({i+1}/{n_tries}) Retrying...") else: break + return comfy_client - # warm up pipeline + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, autouse=True) + def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start client + comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) + + # Warm up pipeline comfy_client.get_images(graph=comfy_graph.graph, save=False) yield comfy_client, comfy_graph del comfy_client del comfy_graph - p.kill() torch.cuda.empty_cache() @fixture @@ -180,7 +188,7 @@ class TestInference: client = _client_graph[0] yield client - # method-scoped fixture for graph to avoid mutating the graph + # function-scoped fixture for graph to avoid mutating the graph @fixture def comfy_graph(self, _client_graph): graph = deepcopy(_client_graph[1])