Fix tests, improve distributed worker health check, add torch compile options

This commit is contained in:
doctorpangloss 2024-09-13 18:10:11 -07:00
parent ffb4ed9cf2
commit 83b2f0174c
15 changed files with 226 additions and 147 deletions

View File

@ -1,25 +1,4 @@
import os
import yaml
import logging
def load_extra_path_config(yaml_path):
from . import folder_paths
from ..extra_config import load_extra_path_config
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info(f"Adding extra search path {x} ({full_path})")
folder_paths.add_model_folder_path(x, full_path)
return load_extra_path_config(yaml_path)

View File

@ -349,7 +349,7 @@ def invalidate_cache(folder_name):
_filename_list_cache.pop(folder_name, None)
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
"""
Example:
files = os.listdir(folder_paths.get_input_directory())

View File

@ -41,7 +41,14 @@ class DistributedPromptWorker:
self._health_check_site: Optional[web.TCPSite] = None
async def _health_check(self, request):
return web.Response(text="OK", content_type="text/plain")
if self._connection is None:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503)
is_healthy = await self._is_connection_healthy()
if is_healthy:
return web.Response(text="HEALTHY", status=200)
else:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503)
async def _start_health_check_server(self):
app = web.Application()
@ -85,9 +92,27 @@ class DistributedPromptWorker:
await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply)
async def _is_connection_healthy(self):
if self._connection is None:
return False
return (
not self._connection.is_closed
and self._connection.connected.is_set()
and await self._check_connection_ready()
)
async def _check_connection_ready(self):
try:
await asyncio.wait_for(self._connection.ready(), timeout=1.0)
return True
except asyncio.TimeoutError:
return False
@tracer.start_as_current_span("Initialize Prompt Worker")
async def init(self):
await self._exit_stack.__aenter__()
await self._start_health_check_server()
try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
@ -102,7 +127,6 @@ class DistributedPromptWorker:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
await self._rpc.register(self._queue_name, self._do_work_item)
await self._start_health_check_server()
async def __aenter__(self) -> "DistributedPromptWorker":
await self.init()

View File

@ -1,9 +1,12 @@
import os
import yaml
import folder_paths
import logging
import os
import yaml
def load_extra_path_config(yaml_path):
from .cmd import folder_paths
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
for c in config:

View File

@ -1,16 +1,17 @@
import math
from scipy import integrate
import torch
from torch import nn
import torchsde
from scipy import integrate
from torch import nn
from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from . import utils
from .. import model_patcher
from .. import model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@ -274,6 +275,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
def fn(tau):
prod = 1.
for k in range(order):
@ -281,6 +283,7 @@ def linear_multistep_coeff(order, t, i, j):
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@ -306,6 +309,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
self.h = h
self.b1 = (pcoeff + icoeff + dcoeff) / order
@ -552,17 +556,17 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
lambda_fn = lambda sigma: ((1 - sigma) / sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@ -590,10 +594,11 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
@ -665,6 +670,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised
return x
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
@ -713,6 +719,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
h_last = h if h is not None else h_last
return x
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE."""
@ -766,6 +773,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
h_1, h_2 = h, h_1
return x
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
@ -775,6 +783,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
@ -784,6 +793,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
@ -804,6 +814,7 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -823,6 +834,7 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
@ -839,7 +851,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""
@ -893,12 +904,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w2 = sigmas[i + 1] / w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
@ -922,8 +932,8 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
# under Apache 2 license
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
@ -943,27 +953,28 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
order = min(max_order, i + 1)
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
elif order == 3: # Use two history points.
elif order == 3: # Use two history points.
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
elif order == 4: # Use three history points.
elif order == 4: # Use three history points.
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur
else:
buffer_model.append(d_cur)
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
# under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
@ -984,32 +995,32 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
if order == 1: # First Euler step.
order = min(max_order, i + 1)
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
elif order == 2: # Use one history point.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_1 = (t_cur - t_steps[i - 1])
coeff1 = (2 + (h_n / h_n_1)) / 2
coeff2 = -(h_n / h_n_1) / 2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
elif order == 3: # Use two history points.
elif order == 3: # Use two history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
h_n_1 = (t_cur - t_steps[i - 1])
h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
coeff3 = temp * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
elif order == 4: # Use three history points.
elif order == 4: # Use three history points.
h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1])
h_n_2 = (t_steps[i-1] - t_steps[i-2])
h_n_3 = (t_steps[i-2] - t_steps[i-3])
h_n_1 = (t_cur - t_steps[i - 1])
h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
h_n_3 = (t_steps[i - 2] - t_steps[i - 3])
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
@ -1018,15 +1029,16 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
# under Apache 2 license
@torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
extra_args = {} if extra_args is None else extra_args
@ -1050,36 +1062,38 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1)
order = min(max_order, i + 1)
if t_next <= 0:
order = 1
if order == 1: # First Euler step.
if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point.
elif order == 2: # Use one history point.
coeff_cur, coeff_prev1 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
elif order == 3: # Use two history points.
elif order == 3: # Use two history points.
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
elif order == 4: # Use three history points.
elif order == 4: # Use three history points.
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
if len(buffer_model) == max_order - 1:
for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1]
buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur.detach()
else:
buffer_model.append(d_cur.detach())
return x_next
@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
@ -1099,6 +1113,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised + d * sigmas[i + 1]
return x
@torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
@ -1106,6 +1121,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
@ -1126,6 +1142,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@ -1133,12 +1151,13 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]
model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()

View File

@ -0,0 +1,52 @@
import logging
import torch
from comfy.model_patcher import ModelPatcher
DIFFUSION_MODEL = "diffusion_model"
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
},
"optional": {
"object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
"fullgraph": ("BOOLEAN", {"default": False}),
"dynamic": ("BOOLEAN", {"default": False}),
"backend": ("STRING", {"default": "inductor"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
if object_patch is None:
object_patch = DIFFUSION_MODEL
compile_kwargs = {
"fullgraph": fullgraph,
"dynamic": dynamic,
"backend": backend
}
if isinstance(model, ModelPatcher):
m = model.clone()
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
return (m,)
elif isinstance(model, torch.nn.Module):
return torch.compile(model=model, **compile_kwargs),
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -1,21 +0,0 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -118,7 +118,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
for _ in range(max_retries):
try:
async with session.get(url, timeout=1) as response:
if response.status == 200 and await response.text() == "OK":
if response.status == 200:
return True
except Exception as exc_info:
pass

View File

@ -59,7 +59,7 @@ class _ProgressHandler(ServerStub):
self.tuples.append((event, data, sid))
class Client:
class ComfyClient:
def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler):
self.embedded_client = embedded_client
self.progress_handler = progress_handler
@ -105,7 +105,7 @@ class TestExecution:
(0,),
(100,),
])
async def client(self, request) -> Client:
async def client(self, request) -> ComfyClient:
from comfy.cmd.execution import nodes
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
@ -115,13 +115,13 @@ class TestExecution:
configuration.cache_lru = lru_size
progress_handler = _ProgressHandler()
async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
yield Client(embedded_client, progress_handler)
yield ComfyClient(embedded_client, progress_handler)
@fixture
def builder(self, request):
yield GraphBuilder(prefix=request.node.name)
async def test_lazy_input(self, client: Client, builder: GraphBuilder):
async def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -138,7 +138,7 @@ class TestExecution:
assert result.did_run(mask)
assert result.did_run(lazy_mix)
async def test_full_cache(self, client: Client, builder: GraphBuilder):
async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@ -152,7 +152,7 @@ class TestExecution:
for node_id, node in g.nodes.items():
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
async def test_partial_cache(self, client: Client, builder: GraphBuilder):
async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@ -167,7 +167,7 @@ class TestExecution:
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
async def test_error(self, client: Client, builder: GraphBuilder):
async def test_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Different size of the two images
@ -188,7 +188,7 @@ class TestExecution:
("foo", True),
(5.0, False),
])
async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder):
async def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
g.node("SaveImage", images=validation1.out(0))
@ -203,7 +203,7 @@ class TestExecution:
("StubInt", 5),
("StubFloat", 5.0)
])
async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder):
async def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
@ -216,7 +216,7 @@ class TestExecution:
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
@ -232,7 +232,7 @@ class TestExecution:
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
@ -248,7 +248,7 @@ class TestExecution:
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
@ -260,7 +260,7 @@ class TestExecution:
else:
await client.run(g)
async def test_cycle_error(self, client: Client, builder: GraphBuilder):
async def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -274,7 +274,7 @@ class TestExecution:
with pytest.raises(ValueError):
await client.run(g)
async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder):
async def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -289,7 +289,7 @@ class TestExecution:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
async def test_custom_is_changed(self, client: Client, builder: GraphBuilder):
async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug
save = g.node("SaveImage")
@ -309,7 +309,7 @@ class TestExecution:
assert result3.did_run(is_changed), "is_changed should have been re-run"
assert result4.did_run(is_changed), "is_changed should not have been cached"
async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder):
async def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -323,7 +323,7 @@ class TestExecution:
expected = 255 // 4
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
async def test_for_loop(self, client: Client, builder: GraphBuilder):
async def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
g = builder
iterations = 4
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -342,7 +342,7 @@ class TestExecution:
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
assert result.did_run(is_changed)
async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder):
async def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
g = builder
val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
@ -361,7 +361,7 @@ class TestExecution:
for i in range(3):
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder):
async def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
g = builder
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
@ -378,7 +378,7 @@ class TestExecution:
assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
async def test_missing_node_error(self, client: Client, builder: GraphBuilder):
async def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
@ -397,7 +397,7 @@ class TestExecution:
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
await client.run(g)
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
async def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -410,9 +410,8 @@ class TestExecution:
assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED`
async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder):
async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@ -433,7 +432,7 @@ class TestExecution:
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
async def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -445,11 +444,11 @@ class TestExecution:
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g)
result = await client.run(g)
assert result.did_run(output), "The execution should have run"
images = result.get_images(output)
assert len(images) == 2, "Should have 2 images"

View File

@ -1,11 +1,13 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down
import pytest
import os
import tempfile
from unittest.mock import patch
import folder_paths
import pytest
from comfy.cmd import folder_paths
@pytest.fixture
def temp_dir():
@ -19,21 +21,25 @@ def test_get_directory_by_type():
assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath():
default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
@ -43,12 +49,14 @@ def test_recursive_search(temp_dir):
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
@ -56,6 +64,7 @@ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
@ -63,4 +72,4 @@ def test_get_save_image_path(temp_dir):
assert filename == "test"
assert counter == 1
assert subfolder == ""
assert filename_prefix == "test"
assert filename_prefix == "test"

View File

@ -1,13 +1,16 @@
import pytest
import os
import tempfile
from folder_paths import filter_files_content_types
import pytest
from comfy.cmd.folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}
@ -49,4 +52,4 @@ def test_handles_no_extension():
def test_handles_no_files():
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

View File

@ -1,10 +1,12 @@
import pytest
import yaml
import os
from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config
import folder_paths
import pytest
import yaml
from comfy.cmd import folder_paths
from comfy.extra_config import load_extra_path_config
@pytest.fixture
def mock_yaml_content():
@ -15,10 +17,12 @@ def mock_yaml_content():
}
}
@pytest.fixture
def mock_expanded_home():
return '/home/user'
@pytest.fixture
def yaml_config_with_appdata():
return """
@ -27,40 +31,47 @@ def yaml_config_with_appdata():
checkpoints: 'models/checkpoints'
"""
@pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture
def mock_expandvars_appdata():
mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock
@pytest.fixture
def mock_add_model_folder_path():
return Mock()
@pytest.fixture
def mock_expanduser(mock_expanded_home):
def _expanduser(path):
if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:])
return path
return _expanduser
@pytest.fixture
def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expanduser,
mock_yaml_safe_load,
mock_expanded_home
):
# Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
@ -75,7 +86,7 @@ def test_load_extra_model_paths_expands_userpath(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
@ -86,14 +97,15 @@ def test_load_extra_model_paths_expands_userpath(
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata(
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
mock_file,
monkeypatch,
mock_add_model_folder_path,
mock_expandvars_appdata,
yaml_config_with_appdata,
mock_yaml_content_appdata
):
# Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata
@ -115,7 +127,7 @@ def test_load_extra_model_paths_expands_appdata(
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call