From 83b2f0174c759148355e8c84e1899c12f9f59d58 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 13 Sep 2024 18:10:11 -0700 Subject: [PATCH] Fix tests, improve distributed worker health check, add torch compile options --- comfy/cmd/extra_model_paths.py | 25 +--- comfy/cmd/folder_paths.py | 2 +- .../distributed/distributed_prompt_worker.py | 28 ++++- {utils => comfy}/extra_config.py | 9 +- comfy/k_diffusion/sampling.py | 107 +++++++++++------- comfy_extras/nodes/nodes_torch_compile.py | 52 +++++++++ comfy_extras/nodes_torch_compile.py | 21 ---- tests/distributed/test_distributed_queue.py | 2 +- tests/inference/test_execution.py | 51 ++++----- .../unit/comfy_test}/__init__.py | 0 .../unit}/comfy_test/folder_path_test.py | 15 ++- .../unit/folder_paths_test}/__init__.py | 0 .../filter_by_content_types_test.py | 13 ++- tests/unit/utils/__init__.py | 0 .../unit}/utils/extra_config_test.py | 48 +++++--- 15 files changed, 226 insertions(+), 147 deletions(-) rename {utils => comfy}/extra_config.py (95%) create mode 100644 comfy_extras/nodes/nodes_torch_compile.py delete mode 100644 comfy_extras/nodes_torch_compile.py rename {tests-unit/folder_paths_test => tests/unit/comfy_test}/__init__.py (100%) rename {tests-unit => tests/unit}/comfy_test/folder_path_test.py (97%) rename {utils => tests/unit/folder_paths_test}/__init__.py (100%) rename {tests-unit => tests/unit}/folder_paths_test/filter_by_content_types_test.py (92%) create mode 100644 tests/unit/utils/__init__.py rename {tests-unit => tests/unit}/utils/extra_config_test.py (89%) diff --git a/comfy/cmd/extra_model_paths.py b/comfy/cmd/extra_model_paths.py index 524f53e86..564e6df99 100644 --- a/comfy/cmd/extra_model_paths.py +++ b/comfy/cmd/extra_model_paths.py @@ -1,25 +1,4 @@ -import os -import yaml -import logging - def load_extra_path_config(yaml_path): - from . import folder_paths + from ..extra_config import load_extra_path_config - with open(yaml_path, 'r') as stream: - config = yaml.safe_load(stream) - for c in config: - conf = config[c] - if conf is None: - continue - base_path = None - if "base_path" in conf: - base_path = conf.pop("base_path") - for x in conf: - for y in conf[x].split("\n"): - if len(y) == 0: - continue - full_path = y - if base_path is not None: - full_path = os.path.join(base_path, full_path) - logging.info(f"Adding extra search path {x} ({full_path})") - folder_paths.add_model_folder_path(x, full_path) + return load_extra_path_config(yaml_path) diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index e5ca1ec90..561bd6dde 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -349,7 +349,7 @@ def invalidate_cache(folder_name): _filename_list_cache.pop(folder_name, None) -def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: +def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory()) diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 76a48911a..e66188064 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -41,7 +41,14 @@ class DistributedPromptWorker: self._health_check_site: Optional[web.TCPSite] = None async def _health_check(self, request): - return web.Response(text="OK", content_type="text/plain") + if self._connection is None: + return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503) + + is_healthy = await self._is_connection_healthy() + if is_healthy: + return web.Response(text="HEALTHY", status=200) + else: + return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503) async def _start_health_check_server(self): app = web.Application() @@ -85,9 +92,27 @@ class DistributedPromptWorker: await self.on_did_complete_work_item(request_obj, reply) return asdict(reply) + async def _is_connection_healthy(self): + if self._connection is None: + return False + + return ( + not self._connection.is_closed + and self._connection.connected.is_set() + and await self._check_connection_ready() + ) + + async def _check_connection_ready(self): + try: + await asyncio.wait_for(self._connection.ready(), timeout=1.0) + return True + except asyncio.TimeoutError: + return False + @tracer.start_as_current_span("Initialize Prompt Worker") async def init(self): await self._exit_stack.__aenter__() + await self._start_health_check_server() try: self._connection = await connect_robust(self._connection_uri, loop=self._loop) except AMQPConnectionError as connection_error: @@ -102,7 +127,6 @@ class DistributedPromptWorker: await self._exit_stack.enter_async_context(self._embedded_comfy_client) await self._rpc.register(self._queue_name, self._do_work_item) - await self._start_health_check_server() async def __aenter__(self) -> "DistributedPromptWorker": await self.init() diff --git a/utils/extra_config.py b/comfy/extra_config.py similarity index 95% rename from utils/extra_config.py rename to comfy/extra_config.py index 7ac3650ac..83759e86b 100644 --- a/utils/extra_config.py +++ b/comfy/extra_config.py @@ -1,9 +1,12 @@ -import os -import yaml -import folder_paths import logging +import os + +import yaml + def load_extra_path_config(yaml_path): + from .cmd import folder_paths + with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) for c in config: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index a597edbf2..dce61495b 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,16 +1,17 @@ import math -from scipy import integrate import torch -from torch import nn import torchsde +from scipy import integrate +from torch import nn from tqdm.auto import trange, tqdm -from . import utils from . import deis +from . import utils from .. import model_patcher from .. import model_sampling + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) @@ -274,6 +275,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis def linear_multistep_coeff(order, t, i, j): if order - 1 > i: raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): prod = 1. for k in range(order): @@ -281,6 +283,7 @@ def linear_multistep_coeff(order, t, i, j): continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] @@ -306,6 +309,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" + def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): self.h = h self.b1 = (pcoeff + icoeff + dcoeff) / order @@ -552,17 +556,17 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 - lambda_fn = lambda sigma: ((1-sigma)/sigma).log() + lambda_fn = lambda sigma: ((1 - sigma) / sigma).log() # logged_x = x.unsqueeze(0) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5 # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) @@ -590,10 +594,11 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) # Noise addition if sigmas[i + 1] > 0 and eta > 0: - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) return x + @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" @@ -665,6 +670,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No old_denoised = denoised return x + @torch.no_grad() def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): """DPM-Solver++(2M) SDE.""" @@ -713,6 +719,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h_last = h if h is not None else h_last return x + @torch.no_grad() def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """DPM-Solver++(3M) SDE.""" @@ -766,6 +773,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h_1, h_2 = h, h_1 return x + @torch.no_grad() def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: @@ -775,6 +783,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + @torch.no_grad() def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: @@ -784,6 +793,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + @torch.no_grad() def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: @@ -804,6 +814,7 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) return mu + def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -823,6 +834,7 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) + @torch.no_grad() def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): extra_args = {} if extra_args is None else extra_args @@ -839,7 +851,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n return x - @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """ @@ -893,12 +904,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non d_2 = to_d(x_2, sigmas[i + 1], denoised_2) w = 2 * sigmas[0] - w2 = sigmas[i+1]/w + w2 = sigmas[i + 1] / w w1 = 1 - w2 d_prime = d * w1 + d_2 * w2 - x = x + d_prime * dt else: @@ -922,8 +932,8 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non return x -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -943,27 +953,28 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if order == 1: # First Euler step. + order = min(max_order, i + 1) + if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 - elif order == 3: # Use two history points. + elif order == 3: # Use two history points. x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 - elif order == 4: # Use three history points. + elif order == 4: # Use three history points. x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur else: buffer_model.append(d_cur) return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license + +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -984,32 +995,32 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if order == 1: # First Euler step. + order = min(max_order, i + 1) + if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) + h_n_1 = (t_cur - t_steps[i - 1]) coeff1 = (2 + (h_n / h_n_1)) / 2 coeff2 = -(h_n / h_n_1) / 2 x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1]) - elif order == 3: # Use two history points. + elif order == 3: # Use two history points. h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) + h_n_1 = (t_cur - t_steps[i - 1]) + h_n_2 = (t_steps[i - 1] - t_steps[i - 2]) temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 coeff1 = (2 + (h_n / h_n_1)) / 2 + temp coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp coeff3 = temp * h_n_1 / h_n_2 x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2]) - elif order == 4: # Use three history points. + elif order == 4: # Use three history points. h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) - h_n_3 = (t_steps[i-2] - t_steps[i-3]) + h_n_1 = (t_cur - t_steps[i - 1]) + h_n_2 = (t_steps[i - 1] - t_steps[i - 2]) + h_n_3 = (t_steps[i - 2] - t_steps[i - 3]) temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \ - * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) + * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2 coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2 coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2 @@ -1018,15 +1029,16 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license + +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license @torch.no_grad() def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'): extra_args = {} if extra_args is None else extra_args @@ -1050,36 +1062,38 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) + order = min(max_order, i + 1) if t_next <= 0: order = 1 - if order == 1: # First Euler step. + if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. coeff_cur, coeff_prev1 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] - elif order == 3: # Use two history points. + elif order == 3: # Use two history points. coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] - elif order == 4: # Use three history points. + elif order == 4: # Use three history points. coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3] if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) return x_next + @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): extra_args = {} if extra_args is None else extra_args temp = [0] + def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] @@ -1099,6 +1113,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl x = denoised + d * sigmas[i + 1] return x + @torch.no_grad() def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" @@ -1106,6 +1121,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler temp = [0] + def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] @@ -1126,6 +1142,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + + @torch.no_grad() def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" @@ -1133,12 +1151,13 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler temp = [0] + def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py new file mode 100644 index 000000000..802856829 --- /dev/null +++ b/comfy_extras/nodes/nodes_torch_compile.py @@ -0,0 +1,52 @@ +import logging + +import torch + +from comfy.model_patcher import ModelPatcher + +DIFFUSION_MODEL = "diffusion_model" + + +class TorchCompileModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + }, + "optional": { + "object_patch": ("STRING", {"default": DIFFUSION_MODEL}), + "fullgraph": ("BOOLEAN", {"default": False}), + "dynamic": ("BOOLEAN", {"default": False}), + "backend": ("STRING", {"default": "inductor"}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + EXPERIMENTAL = True + + def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"): + if object_patch is None: + object_patch = DIFFUSION_MODEL + compile_kwargs = { + "fullgraph": fullgraph, + "dynamic": dynamic, + "backend": backend + } + if isinstance(model, ModelPatcher): + m = model.clone() + m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) + return (m,) + elif isinstance(model, torch.nn.Module): + return torch.compile(model=model, **compile_kwargs), + else: + logging.warning("Encountered a model that cannot be compiled") + return model, + + +NODE_CLASS_MAPPINGS = { + "TorchCompileModel": TorchCompileModel, +} diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py deleted file mode 100644 index 1d914fa93..000000000 --- a/comfy_extras/nodes_torch_compile.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -class TorchCompileModel: - @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "_for_testing" - EXPERIMENTAL = True - - def patch(self, model): - m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"))) - return (m, ) - -NODE_CLASS_MAPPINGS = { - "TorchCompileModel": TorchCompileModel, -} diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 4fb228362..3e5084143 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -118,7 +118,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0) for _ in range(max_retries): try: async with session.get(url, timeout=1) as response: - if response.status == 200 and await response.text() == "OK": + if response.status == 200: return True except Exception as exc_info: pass diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 4a7d0ddaf..a04b2c55c 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -59,7 +59,7 @@ class _ProgressHandler(ServerStub): self.tuples.append((event, data, sid)) -class Client: +class ComfyClient: def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler): self.embedded_client = embedded_client self.progress_handler = progress_handler @@ -105,7 +105,7 @@ class TestExecution: (0,), (100,), ]) - async def client(self, request) -> Client: + async def client(self, request) -> ComfyClient: from comfy.cmd.execution import nodes from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS @@ -115,13 +115,13 @@ class TestExecution: configuration.cache_lru = lru_size progress_handler = _ProgressHandler() async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client: - yield Client(embedded_client, progress_handler) + yield ComfyClient(embedded_client, progress_handler) @fixture def builder(self, request): yield GraphBuilder(prefix=request.node.name) - async def test_lazy_input(self, client: Client, builder: GraphBuilder): + async def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -138,7 +138,7 @@ class TestExecution: assert result.did_run(mask) assert result.did_run(lazy_mix) - async def test_full_cache(self, client: Client, builder: GraphBuilder): + async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -152,7 +152,7 @@ class TestExecution: for node_id, node in g.nodes.items(): assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" - async def test_partial_cache(self, client: Client, builder: GraphBuilder): + async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -167,7 +167,7 @@ class TestExecution: assert not result2.did_run(input1), "Input1 should have been cached" assert not result2.did_run(input2), "Input2 should have been cached" - async def test_error(self, client: Client, builder: GraphBuilder): + async def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) # Different size of the two images @@ -188,7 +188,7 @@ class TestExecution: ("foo", True), (5.0, False), ]) - async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder): + async def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): g = builder validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) g.node("SaveImage", images=validation1.out(0)) @@ -203,7 +203,7 @@ class TestExecution: ("StubInt", 5), ("StubFloat", 5.0) ]) - async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder): + async def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): g = builder stub = g.node(test_type, value=test_value) validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) @@ -216,7 +216,7 @@ class TestExecution: ("StubInt", 5, True), ("StubFloat", 5.0, False) ]) - async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): g = builder stub = g.node(test_type, value=test_value) validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) @@ -232,7 +232,7 @@ class TestExecution: ("StubInt", 5, True), ("StubFloat", 5.0, False) ]) - async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): g = builder stub = g.node(test_type, value=test_value) validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) @@ -248,7 +248,7 @@ class TestExecution: ("StubInt", 5, True), ("StubFloat", 5.0, False) ]) - async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): g = builder stub = g.node(test_type, value=test_value) validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) @@ -260,7 +260,7 @@ class TestExecution: else: await client.run(g) - async def test_cycle_error(self, client: Client, builder: GraphBuilder): + async def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -274,7 +274,7 @@ class TestExecution: with pytest.raises(ValueError): await client.run(g) - async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder): + async def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -289,7 +289,7 @@ class TestExecution: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" - async def test_custom_is_changed(self, client: Client, builder: GraphBuilder): + async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -309,7 +309,7 @@ class TestExecution: assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" - async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder): + async def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -323,7 +323,7 @@ class TestExecution: expected = 255 // 4 assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" - async def test_for_loop(self, client: Client, builder: GraphBuilder): + async def test_for_loop(self, client: ComfyClient, builder: GraphBuilder): g = builder iterations = 4 input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -342,7 +342,7 @@ class TestExecution: assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" assert result.did_run(is_changed) - async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder): + async def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder): g = builder val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3) mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0)) @@ -361,7 +361,7 @@ class TestExecution: for i in range(3): assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" - async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder): + async def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder): g = builder val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0) mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1) @@ -378,7 +378,7 @@ class TestExecution: assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5" assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0" - async def test_missing_node_error(self, client: Client, builder: GraphBuilder): + async def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) @@ -397,7 +397,7 @@ class TestExecution: input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) await client.run(g) - async def test_output_reuse(self, client: Client, builder: GraphBuilder): + async def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -410,9 +410,8 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" - # This tests that only constant outputs are used in the call to `IS_CHANGED` - async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder): + async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -433,7 +432,7 @@ class TestExecution: # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, # only that one entry in the list is blocked. - def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder): + async def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder): g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -445,11 +444,11 @@ class TestExecution: int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0)) compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==") blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) - + list_output = g.node("TestMakeListNode", value1=blocker.out(0)) output = g.node("PreviewImage", images=list_output.out(0)) - result = client.run(g) + result = await client.run(g) assert result.did_run(output), "The execution should have run" images = result.get_images(output) assert len(images) == 2, "Should have 2 images" diff --git a/tests-unit/folder_paths_test/__init__.py b/tests/unit/comfy_test/__init__.py similarity index 100% rename from tests-unit/folder_paths_test/__init__.py rename to tests/unit/comfy_test/__init__.py diff --git a/tests-unit/comfy_test/folder_path_test.py b/tests/unit/comfy_test/folder_path_test.py similarity index 97% rename from tests-unit/comfy_test/folder_path_test.py rename to tests/unit/comfy_test/folder_path_test.py index 0bbec593b..82c2930e8 100644 --- a/tests-unit/comfy_test/folder_path_test.py +++ b/tests/unit/comfy_test/folder_path_test.py @@ -1,11 +1,13 @@ ### 🗻 This file is created through the spirit of Mount Fuji at its peak # TODO(yoland): clean up this after I get back down -import pytest import os import tempfile from unittest.mock import patch -import folder_paths +import pytest + +from comfy.cmd import folder_paths + @pytest.fixture def temp_dir(): @@ -19,21 +21,25 @@ def test_get_directory_by_type(): assert folder_paths.get_directory_by_type("output") == test_dir assert folder_paths.get_directory_by_type("invalid") is None + def test_annotated_filepath(): assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None) assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory()) assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory()) assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory()) + def test_get_annotated_filepath(): default_dir = "/default/dir" assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt") assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt") + def test_add_model_folder_path(): folder_paths.add_model_folder_path("test_folder", "/test/path") assert "/test/path" in folder_paths.get_folder_paths("test_folder") + def test_recursive_search(temp_dir): os.makedirs(os.path.join(temp_dir, "subdir")) open(os.path.join(temp_dir, "file1.txt"), "w").close() @@ -43,12 +49,14 @@ def test_recursive_search(temp_dir): assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")} assert len(dirs) == 2 # temp_dir and subdir + def test_filter_files_extensions(): files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"] assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"] assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"] assert folder_paths.filter_files_extensions(files, []) == files + @patch("folder_paths.recursive_search") @patch("folder_paths.folder_names_and_paths") def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): @@ -56,6 +64,7 @@ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {}) assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] + def test_get_save_image_path(temp_dir): with patch("folder_paths.output_directory", temp_dir): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100) @@ -63,4 +72,4 @@ def test_get_save_image_path(temp_dir): assert filename == "test" assert counter == 1 assert subfolder == "" - assert filename_prefix == "test" \ No newline at end of file + assert filename_prefix == "test" diff --git a/utils/__init__.py b/tests/unit/folder_paths_test/__init__.py similarity index 100% rename from utils/__init__.py rename to tests/unit/folder_paths_test/__init__.py diff --git a/tests-unit/folder_paths_test/filter_by_content_types_test.py b/tests/unit/folder_paths_test/filter_by_content_types_test.py similarity index 92% rename from tests-unit/folder_paths_test/filter_by_content_types_test.py rename to tests/unit/folder_paths_test/filter_by_content_types_test.py index 5941bfa94..e0b2876fa 100644 --- a/tests-unit/folder_paths_test/filter_by_content_types_test.py +++ b/tests/unit/folder_paths_test/filter_by_content_types_test.py @@ -1,13 +1,16 @@ -import pytest import os import tempfile -from folder_paths import filter_files_content_types + +import pytest + +from comfy.cmd.folder_paths import filter_files_content_types + @pytest.fixture(scope="module") def file_extensions(): return { - 'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'], - 'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'], + 'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'], + 'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'], 'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'] } @@ -49,4 +52,4 @@ def test_handles_no_extension(): def test_handles_no_files(): files = [] - assert filter_files_content_types(files, ["image", "audio", "video"]) == [] \ No newline at end of file + assert filter_files_content_types(files, ["image", "audio", "video"]) == [] diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/utils/extra_config_test.py b/tests/unit/utils/extra_config_test.py similarity index 89% rename from tests-unit/utils/extra_config_test.py rename to tests/unit/utils/extra_config_test.py index 06349560d..b678cc16f 100644 --- a/tests-unit/utils/extra_config_test.py +++ b/tests/unit/utils/extra_config_test.py @@ -1,10 +1,12 @@ -import pytest -import yaml import os from unittest.mock import Mock, patch, mock_open -from utils.extra_config import load_extra_path_config -import folder_paths +import pytest +import yaml + +from comfy.cmd import folder_paths +from comfy.extra_config import load_extra_path_config + @pytest.fixture def mock_yaml_content(): @@ -15,10 +17,12 @@ def mock_yaml_content(): } } + @pytest.fixture def mock_expanded_home(): return '/home/user' + @pytest.fixture def yaml_config_with_appdata(): return """ @@ -27,40 +31,47 @@ def yaml_config_with_appdata(): checkpoints: 'models/checkpoints' """ + @pytest.fixture def mock_yaml_content_appdata(yaml_config_with_appdata): return yaml.safe_load(yaml_config_with_appdata) + @pytest.fixture def mock_expandvars_appdata(): mock = Mock() mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming') return mock + @pytest.fixture def mock_add_model_folder_path(): return Mock() + @pytest.fixture def mock_expanduser(mock_expanded_home): def _expanduser(path): if path.startswith('~/'): return os.path.join(mock_expanded_home, path[2:]) return path + return _expanduser + @pytest.fixture def mock_yaml_safe_load(mock_yaml_content): return Mock(return_value=mock_yaml_content) + @patch('builtins.open', new_callable=mock_open, read_data="dummy file content") def test_load_extra_model_paths_expands_userpath( - mock_file, - monkeypatch, - mock_add_model_folder_path, - mock_expanduser, - mock_yaml_safe_load, - mock_expanded_home + mock_file, + monkeypatch, + mock_add_model_folder_path, + mock_expanduser, + mock_yaml_safe_load, + mock_expanded_home ): # Attach mocks used by load_extra_path_config monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path) @@ -75,7 +86,7 @@ def test_load_extra_model_paths_expands_userpath( ] assert mock_add_model_folder_path.call_count == len(expected_calls) - + # Check if add_model_folder_path was called with the correct arguments for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): assert actual_call.args == expected_call @@ -86,14 +97,15 @@ def test_load_extra_model_paths_expands_userpath( # Check if open was called with the correct file path mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') + @patch('builtins.open', new_callable=mock_open) def test_load_extra_model_paths_expands_appdata( - mock_file, - monkeypatch, - mock_add_model_folder_path, - mock_expandvars_appdata, - yaml_config_with_appdata, - mock_yaml_content_appdata + mock_file, + monkeypatch, + mock_add_model_folder_path, + mock_expandvars_appdata, + yaml_config_with_appdata, + mock_yaml_content_appdata ): # Set the mock_file to return yaml with appdata as a variable mock_file.return_value.read.return_value = yaml_config_with_appdata @@ -115,7 +127,7 @@ def test_load_extra_model_paths_expands_appdata( ] assert mock_add_model_folder_path.call_count == len(expected_calls) - + # Check the base path variable was expanded for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): assert actual_call.args == expected_call