Fix tests, improve distributed worker health check, add torch compile options

This commit is contained in:
doctorpangloss 2024-09-13 18:10:11 -07:00
parent ffb4ed9cf2
commit 83b2f0174c
15 changed files with 226 additions and 147 deletions

View File

@ -1,25 +1,4 @@
import os
import yaml
import logging
def load_extra_path_config(yaml_path): def load_extra_path_config(yaml_path):
from . import folder_paths from ..extra_config import load_extra_path_config
with open(yaml_path, 'r') as stream: return load_extra_path_config(yaml_path)
config = yaml.safe_load(stream)
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
continue
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info(f"Adding extra search path {x} ({full_path})")
folder_paths.add_model_folder_path(x, full_path)

View File

@ -349,7 +349,7 @@ def invalidate_cache(folder_name):
_filename_list_cache.pop(folder_name, None) _filename_list_cache.pop(folder_name, None)
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
""" """
Example: Example:
files = os.listdir(folder_paths.get_input_directory()) files = os.listdir(folder_paths.get_input_directory())

View File

@ -41,7 +41,14 @@ class DistributedPromptWorker:
self._health_check_site: Optional[web.TCPSite] = None self._health_check_site: Optional[web.TCPSite] = None
async def _health_check(self, request): async def _health_check(self, request):
return web.Response(text="OK", content_type="text/plain") if self._connection is None:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503)
is_healthy = await self._is_connection_healthy()
if is_healthy:
return web.Response(text="HEALTHY", status=200)
else:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503)
async def _start_health_check_server(self): async def _start_health_check_server(self):
app = web.Application() app = web.Application()
@ -85,9 +92,27 @@ class DistributedPromptWorker:
await self.on_did_complete_work_item(request_obj, reply) await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply) return asdict(reply)
async def _is_connection_healthy(self):
if self._connection is None:
return False
return (
not self._connection.is_closed
and self._connection.connected.is_set()
and await self._check_connection_ready()
)
async def _check_connection_ready(self):
try:
await asyncio.wait_for(self._connection.ready(), timeout=1.0)
return True
except asyncio.TimeoutError:
return False
@tracer.start_as_current_span("Initialize Prompt Worker") @tracer.start_as_current_span("Initialize Prompt Worker")
async def init(self): async def init(self):
await self._exit_stack.__aenter__() await self._exit_stack.__aenter__()
await self._start_health_check_server()
try: try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop) self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error: except AMQPConnectionError as connection_error:
@ -102,7 +127,6 @@ class DistributedPromptWorker:
await self._exit_stack.enter_async_context(self._embedded_comfy_client) await self._exit_stack.enter_async_context(self._embedded_comfy_client)
await self._rpc.register(self._queue_name, self._do_work_item) await self._rpc.register(self._queue_name, self._do_work_item)
await self._start_health_check_server()
async def __aenter__(self) -> "DistributedPromptWorker": async def __aenter__(self) -> "DistributedPromptWorker":
await self.init() await self.init()

View File

@ -1,9 +1,12 @@
import os
import yaml
import folder_paths
import logging import logging
import os
import yaml
def load_extra_path_config(yaml_path): def load_extra_path_config(yaml_path):
from .cmd import folder_paths
with open(yaml_path, 'r') as stream: with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream) config = yaml.safe_load(stream)
for c in config: for c in config:

View File

@ -1,16 +1,17 @@
import math import math
from scipy import integrate
import torch import torch
from torch import nn
import torchsde import torchsde
from scipy import integrate
from torch import nn
from tqdm.auto import trange, tqdm from tqdm.auto import trange, tqdm
from . import utils
from . import deis from . import deis
from . import utils
from .. import model_patcher from .. import model_patcher
from .. import model_sampling from .. import model_sampling
def append_zero(x): def append_zero(x):
return torch.cat([x, x.new_zeros([1])]) return torch.cat([x, x.new_zeros([1])])
@ -274,6 +275,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
def linear_multistep_coeff(order, t, i, j): def linear_multistep_coeff(order, t, i, j):
if order - 1 > i: if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}') raise ValueError(f'Order {order} too high for step {i}')
def fn(tau): def fn(tau):
prod = 1. prod = 1.
for k in range(order): for k in range(order):
@ -281,6 +283,7 @@ def linear_multistep_coeff(order, t, i, j):
continue continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@ -306,6 +309,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
class PIDStepSizeController: class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control.""" """A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
self.h = h self.h = h
self.b1 = (pcoeff + icoeff + dcoeff) / order self.b1 = (pcoeff + icoeff + dcoeff) / order
@ -552,17 +556,17 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log() lambda_fn = lambda sigma: ((1 - sigma) / sigma).log()
# logged_x = x.unsqueeze(0) # logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1] alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@ -590,10 +594,11 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition # Noise addition
if sigmas[i + 1] > 0 and eta > 0: if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic).""" """DPM-Solver++ (stochastic)."""
@ -665,6 +670,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
old_denoised = denoised old_denoised = denoised
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE.""" """DPM-Solver++(2M) SDE."""
@ -713,6 +719,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
h_last = h if h is not None else h_last h_last = h if h is not None else h_last
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""DPM-Solver++(3M) SDE.""" """DPM-Solver++(3M) SDE."""
@ -766,6 +773,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
h_1, h_2 = h, h_1 h_1, h_2 = h, h_1
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1: if len(sigmas) <= 1:
@ -775,6 +783,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1: if len(sigmas) <= 1:
@ -784,6 +793,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1: if len(sigmas) <= 1:
@ -804,6 +814,7 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -823,6 +834,7 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad() @torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@ -839,7 +851,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
return x return x
@torch.no_grad() @torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
""" """
@ -893,12 +904,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0] w = 2 * sigmas[0]
w2 = sigmas[i+1]/w w2 = sigmas[i + 1] / w
w1 = 1 - w2 w1 = 1 - w2
d_prime = d * w1 + d_2 * w2 d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt x = x + d_prime * dt
else: else:
@ -922,8 +932,8 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
return x return x
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license # under Apache 2 license
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
@ -943,27 +953,28 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1) order = min(max_order, i + 1)
if order == 1: # First Euler step. if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point. elif order == 2: # Use one history point.
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
elif order == 3: # Use two history points. elif order == 3: # Use two history points.
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
elif order == 4: # Use three history points. elif order == 4: # Use three history points.
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
if len(buffer_model) == max_order - 1: if len(buffer_model) == max_order - 1:
for k in range(max_order - 2): for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1] buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur buffer_model[-1] = d_cur
else: else:
buffer_model.append(d_cur) buffer_model.append(d_cur)
return x_next return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
# under Apache 2 license
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
@ -984,32 +995,32 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
d_cur = (x_cur - denoised) / t_cur d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1) order = min(max_order, i + 1)
if order == 1: # First Euler step. if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point. elif order == 2: # Use one history point.
h_n = (t_next - t_cur) h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1]) h_n_1 = (t_cur - t_steps[i - 1])
coeff1 = (2 + (h_n / h_n_1)) / 2 coeff1 = (2 + (h_n / h_n_1)) / 2
coeff2 = -(h_n / h_n_1) / 2 coeff2 = -(h_n / h_n_1) / 2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1]) x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
elif order == 3: # Use two history points. elif order == 3: # Use two history points.
h_n = (t_next - t_cur) h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1]) h_n_1 = (t_cur - t_steps[i - 1])
h_n_2 = (t_steps[i-1] - t_steps[i-2]) h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
coeff3 = temp * h_n_1 / h_n_2 coeff3 = temp * h_n_1 / h_n_2
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2]) x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
elif order == 4: # Use three history points. elif order == 4: # Use three history points.
h_n = (t_next - t_cur) h_n = (t_next - t_cur)
h_n_1 = (t_cur - t_steps[i-1]) h_n_1 = (t_cur - t_steps[i - 1])
h_n_2 = (t_steps[i-1] - t_steps[i-2]) h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
h_n_3 = (t_steps[i-2] - t_steps[i-3]) h_n_3 = (t_steps[i - 2] - t_steps[i - 3])
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \ temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2 coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2 coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2 coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
@ -1018,15 +1029,16 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
if len(buffer_model) == max_order - 1: if len(buffer_model) == max_order - 1:
for k in range(max_order - 2): for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1] buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur.detach() buffer_model[-1] = d_cur.detach()
else: else:
buffer_model.append(d_cur.detach()) buffer_model.append(d_cur.detach())
return x_next return x_next
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
#under Apache 2 license # From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
# under Apache 2 license
@torch.no_grad() @torch.no_grad()
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'): def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@ -1050,36 +1062,38 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
d_cur = (x_cur - denoised) / t_cur d_cur = (x_cur - denoised) / t_cur
order = min(max_order, i+1) order = min(max_order, i + 1)
if t_next <= 0: if t_next <= 0:
order = 1 order = 1
if order == 1: # First Euler step. if order == 1: # First Euler step.
x_next = x_cur + (t_next - t_cur) * d_cur x_next = x_cur + (t_next - t_cur) * d_cur
elif order == 2: # Use one history point. elif order == 2: # Use one history point.
coeff_cur, coeff_prev1 = coeff_list[i] coeff_cur, coeff_prev1 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
elif order == 3: # Use two history points. elif order == 3: # Use two history points.
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i] coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
elif order == 4: # Use three history points. elif order == 4: # Use three history points.
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i] coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
if len(buffer_model) == max_order - 1: if len(buffer_model) == max_order - 1:
for k in range(max_order - 2): for k in range(max_order - 2):
buffer_model[k] = buffer_model[k+1] buffer_model[k] = buffer_model[k + 1]
buffer_model[-1] = d_cur.detach() buffer_model[-1] = d_cur.detach()
else: else:
buffer_model.append(d_cur.detach()) buffer_model.append(d_cur.detach())
return x_next return x_next
@torch.no_grad() @torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
temp = [0] temp = [0]
def post_cfg_function(args): def post_cfg_function(args):
temp[0] = args["uncond_denoised"] temp[0] = args["uncond_denoised"]
return args["denoised"] return args["denoised"]
@ -1099,6 +1113,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
x = denoised + d * sigmas[i + 1] x = denoised + d * sigmas[i + 1]
return x return x
@torch.no_grad() @torch.no_grad()
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps.""" """Ancestral sampling with Euler method steps."""
@ -1106,6 +1121,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0] temp = [0]
def post_cfg_function(args): def post_cfg_function(args):
temp[0] = args["uncond_denoised"] temp[0] = args["uncond_denoised"]
return args["denoised"] return args["denoised"]
@ -1126,6 +1142,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
if sigmas[i + 1] > 0: if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x return x
@torch.no_grad() @torch.no_grad()
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps.""" """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@ -1133,12 +1151,13 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
temp = [0] temp = [0]
def post_cfg_function(args): def post_cfg_function(args):
temp[0] = args["uncond_denoised"] temp[0] = args["uncond_denoised"]
return args["denoised"] return args["denoised"]
model_options = extra_args.get("model_options", {}).copy() model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp() sigma_fn = lambda t: t.neg().exp()

View File

@ -0,0 +1,52 @@
import logging
import torch
from comfy.model_patcher import ModelPatcher
DIFFUSION_MODEL = "diffusion_model"
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
},
"optional": {
"object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
"fullgraph": ("BOOLEAN", {"default": False}),
"dynamic": ("BOOLEAN", {"default": False}),
"backend": ("STRING", {"default": "inductor"}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
if object_patch is None:
object_patch = DIFFUSION_MODEL
compile_kwargs = {
"fullgraph": fullgraph,
"dynamic": dynamic,
"backend": backend
}
if isinstance(model, ModelPatcher):
m = model.clone()
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
return (m,)
elif isinstance(model, torch.nn.Module):
return torch.compile(model=model, **compile_kwargs),
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -1,21 +0,0 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

View File

@ -118,7 +118,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
for _ in range(max_retries): for _ in range(max_retries):
try: try:
async with session.get(url, timeout=1) as response: async with session.get(url, timeout=1) as response:
if response.status == 200 and await response.text() == "OK": if response.status == 200:
return True return True
except Exception as exc_info: except Exception as exc_info:
pass pass

View File

@ -59,7 +59,7 @@ class _ProgressHandler(ServerStub):
self.tuples.append((event, data, sid)) self.tuples.append((event, data, sid))
class Client: class ComfyClient:
def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler): def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler):
self.embedded_client = embedded_client self.embedded_client = embedded_client
self.progress_handler = progress_handler self.progress_handler = progress_handler
@ -105,7 +105,7 @@ class TestExecution:
(0,), (0,),
(100,), (100,),
]) ])
async def client(self, request) -> Client: async def client(self, request) -> ComfyClient:
from comfy.cmd.execution import nodes from comfy.cmd.execution import nodes
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
@ -115,13 +115,13 @@ class TestExecution:
configuration.cache_lru = lru_size configuration.cache_lru = lru_size
progress_handler = _ProgressHandler() progress_handler = _ProgressHandler()
async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client: async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
yield Client(embedded_client, progress_handler) yield ComfyClient(embedded_client, progress_handler)
@fixture @fixture
def builder(self, request): def builder(self, request):
yield GraphBuilder(prefix=request.node.name) yield GraphBuilder(prefix=request.node.name)
async def test_lazy_input(self, client: Client, builder: GraphBuilder): async def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -138,7 +138,7 @@ class TestExecution:
assert result.did_run(mask) assert result.did_run(mask)
assert result.did_run(lazy_mix) assert result.did_run(lazy_mix)
async def test_full_cache(self, client: Client, builder: GraphBuilder): async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@ -152,7 +152,7 @@ class TestExecution:
for node_id, node in g.nodes.items(): for node_id, node in g.nodes.items():
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
async def test_partial_cache(self, client: Client, builder: GraphBuilder): async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@ -167,7 +167,7 @@ class TestExecution:
assert not result2.did_run(input1), "Input1 should have been cached" assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached" assert not result2.did_run(input2), "Input2 should have been cached"
async def test_error(self, client: Client, builder: GraphBuilder): async def test_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
# Different size of the two images # Different size of the two images
@ -188,7 +188,7 @@ class TestExecution:
("foo", True), ("foo", True),
(5.0, False), (5.0, False),
]) ])
async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder): async def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
g.node("SaveImage", images=validation1.out(0)) g.node("SaveImage", images=validation1.out(0))
@ -203,7 +203,7 @@ class TestExecution:
("StubInt", 5), ("StubInt", 5),
("StubFloat", 5.0) ("StubFloat", 5.0)
]) ])
async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder): async def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
stub = g.node(test_type, value=test_value) stub = g.node(test_type, value=test_value)
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
@ -216,7 +216,7 @@ class TestExecution:
("StubInt", 5, True), ("StubInt", 5, True),
("StubFloat", 5.0, False) ("StubFloat", 5.0, False)
]) ])
async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
stub = g.node(test_type, value=test_value) stub = g.node(test_type, value=test_value)
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
@ -232,7 +232,7 @@ class TestExecution:
("StubInt", 5, True), ("StubInt", 5, True),
("StubFloat", 5.0, False) ("StubFloat", 5.0, False)
]) ])
async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
stub = g.node(test_type, value=test_value) stub = g.node(test_type, value=test_value)
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
@ -248,7 +248,7 @@ class TestExecution:
("StubInt", 5, True), ("StubInt", 5, True),
("StubFloat", 5.0, False) ("StubFloat", 5.0, False)
]) ])
async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
stub = g.node(test_type, value=test_value) stub = g.node(test_type, value=test_value)
validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
@ -260,7 +260,7 @@ class TestExecution:
else: else:
await client.run(g) await client.run(g)
async def test_cycle_error(self, client: Client, builder: GraphBuilder): async def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -274,7 +274,7 @@ class TestExecution:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await client.run(g) await client.run(g)
async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder): async def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -289,7 +289,7 @@ class TestExecution:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
async def test_custom_is_changed(self, client: Client, builder: GraphBuilder): async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
# Creating the nodes in this specific order previously caused a bug # Creating the nodes in this specific order previously caused a bug
save = g.node("SaveImage") save = g.node("SaveImage")
@ -309,7 +309,7 @@ class TestExecution:
assert result3.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run"
assert result4.did_run(is_changed), "is_changed should not have been cached" assert result4.did_run(is_changed), "is_changed should not have been cached"
async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder): async def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -323,7 +323,7 @@ class TestExecution:
expected = 255 // 4 expected = 255 // 4
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
async def test_for_loop(self, client: Client, builder: GraphBuilder): async def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
iterations = 4 iterations = 4
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -342,7 +342,7 @@ class TestExecution:
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
assert result.did_run(is_changed) assert result.did_run(is_changed)
async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder): async def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3) val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0)) mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
@ -361,7 +361,7 @@ class TestExecution:
for i in range(3): for i in range(3):
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder): async def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0) val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1) mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
@ -378,7 +378,7 @@ class TestExecution:
assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5" assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0" assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
async def test_missing_node_error(self, client: Client, builder: GraphBuilder): async def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
@ -397,7 +397,7 @@ class TestExecution:
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
await client.run(g) await client.run(g)
async def test_output_reuse(self, client: Client, builder: GraphBuilder): async def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@ -410,9 +410,8 @@ class TestExecution:
assert len(images1) == 1, "Should have 1 image" assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED` # This tests that only constant outputs are used in the call to `IS_CHANGED`
async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder): async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@ -433,7 +432,7 @@ class TestExecution:
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node, # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
# only that one entry in the list is blocked. # only that one entry in the list is blocked.
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder): async def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
g = builder g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@ -445,11 +444,11 @@ class TestExecution:
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0)) int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==") compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
list_output = g.node("TestMakeListNode", value1=blocker.out(0)) list_output = g.node("TestMakeListNode", value1=blocker.out(0))
output = g.node("PreviewImage", images=list_output.out(0)) output = g.node("PreviewImage", images=list_output.out(0))
result = client.run(g) result = await client.run(g)
assert result.did_run(output), "The execution should have run" assert result.did_run(output), "The execution should have run"
images = result.get_images(output) images = result.get_images(output)
assert len(images) == 2, "Should have 2 images" assert len(images) == 2, "Should have 2 images"

View File

@ -1,11 +1,13 @@
### 🗻 This file is created through the spirit of Mount Fuji at its peak ### 🗻 This file is created through the spirit of Mount Fuji at its peak
# TODO(yoland): clean up this after I get back down # TODO(yoland): clean up this after I get back down
import pytest
import os import os
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
import folder_paths import pytest
from comfy.cmd import folder_paths
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
@ -19,21 +21,25 @@ def test_get_directory_by_type():
assert folder_paths.get_directory_by_type("output") == test_dir assert folder_paths.get_directory_by_type("output") == test_dir
assert folder_paths.get_directory_by_type("invalid") is None assert folder_paths.get_directory_by_type("invalid") is None
def test_annotated_filepath(): def test_annotated_filepath():
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None) assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory()) assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory()) assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory()) assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
def test_get_annotated_filepath(): def test_get_annotated_filepath():
default_dir = "/default/dir" default_dir = "/default/dir"
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt") assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt") assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
def test_add_model_folder_path(): def test_add_model_folder_path():
folder_paths.add_model_folder_path("test_folder", "/test/path") folder_paths.add_model_folder_path("test_folder", "/test/path")
assert "/test/path" in folder_paths.get_folder_paths("test_folder") assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir): def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir")) os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close() open(os.path.join(temp_dir, "file1.txt"), "w").close()
@ -43,12 +49,14 @@ def test_recursive_search(temp_dir):
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")} assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions(): def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"] files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"] assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"] assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
assert folder_paths.filter_files_extensions(files, []) == files assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search") @patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths") @patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
@ -56,6 +64,7 @@ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {}) mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir): def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir): with patch("folder_paths.output_directory", temp_dir):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
@ -63,4 +72,4 @@ def test_get_save_image_path(temp_dir):
assert filename == "test" assert filename == "test"
assert counter == 1 assert counter == 1
assert subfolder == "" assert subfolder == ""
assert filename_prefix == "test" assert filename_prefix == "test"

View File

@ -1,13 +1,16 @@
import pytest
import os import os
import tempfile import tempfile
from folder_paths import filter_files_content_types
import pytest
from comfy.cmd.folder_paths import filter_files_content_types
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def file_extensions(): def file_extensions():
return { return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'], 'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'], 'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'] 'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
} }
@ -49,4 +52,4 @@ def test_handles_no_extension():
def test_handles_no_files(): def test_handles_no_files():
files = [] files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == [] assert filter_files_content_types(files, ["image", "audio", "video"]) == []

View File

View File

@ -1,10 +1,12 @@
import pytest
import yaml
import os import os
from unittest.mock import Mock, patch, mock_open from unittest.mock import Mock, patch, mock_open
from utils.extra_config import load_extra_path_config import pytest
import folder_paths import yaml
from comfy.cmd import folder_paths
from comfy.extra_config import load_extra_path_config
@pytest.fixture @pytest.fixture
def mock_yaml_content(): def mock_yaml_content():
@ -15,10 +17,12 @@ def mock_yaml_content():
} }
} }
@pytest.fixture @pytest.fixture
def mock_expanded_home(): def mock_expanded_home():
return '/home/user' return '/home/user'
@pytest.fixture @pytest.fixture
def yaml_config_with_appdata(): def yaml_config_with_appdata():
return """ return """
@ -27,40 +31,47 @@ def yaml_config_with_appdata():
checkpoints: 'models/checkpoints' checkpoints: 'models/checkpoints'
""" """
@pytest.fixture @pytest.fixture
def mock_yaml_content_appdata(yaml_config_with_appdata): def mock_yaml_content_appdata(yaml_config_with_appdata):
return yaml.safe_load(yaml_config_with_appdata) return yaml.safe_load(yaml_config_with_appdata)
@pytest.fixture @pytest.fixture
def mock_expandvars_appdata(): def mock_expandvars_appdata():
mock = Mock() mock = Mock()
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming') mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
return mock return mock
@pytest.fixture @pytest.fixture
def mock_add_model_folder_path(): def mock_add_model_folder_path():
return Mock() return Mock()
@pytest.fixture @pytest.fixture
def mock_expanduser(mock_expanded_home): def mock_expanduser(mock_expanded_home):
def _expanduser(path): def _expanduser(path):
if path.startswith('~/'): if path.startswith('~/'):
return os.path.join(mock_expanded_home, path[2:]) return os.path.join(mock_expanded_home, path[2:])
return path return path
return _expanduser return _expanduser
@pytest.fixture @pytest.fixture
def mock_yaml_safe_load(mock_yaml_content): def mock_yaml_safe_load(mock_yaml_content):
return Mock(return_value=mock_yaml_content) return Mock(return_value=mock_yaml_content)
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content") @patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
def test_load_extra_model_paths_expands_userpath( def test_load_extra_model_paths_expands_userpath(
mock_file, mock_file,
monkeypatch, monkeypatch,
mock_add_model_folder_path, mock_add_model_folder_path,
mock_expanduser, mock_expanduser,
mock_yaml_safe_load, mock_yaml_safe_load,
mock_expanded_home mock_expanded_home
): ):
# Attach mocks used by load_extra_path_config # Attach mocks used by load_extra_path_config
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path) monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
@ -75,7 +86,7 @@ def test_load_extra_model_paths_expands_userpath(
] ]
assert mock_add_model_folder_path.call_count == len(expected_calls) assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments # Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call assert actual_call.args == expected_call
@ -86,14 +97,15 @@ def test_load_extra_model_paths_expands_userpath(
# Check if open was called with the correct file path # Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
@patch('builtins.open', new_callable=mock_open) @patch('builtins.open', new_callable=mock_open)
def test_load_extra_model_paths_expands_appdata( def test_load_extra_model_paths_expands_appdata(
mock_file, mock_file,
monkeypatch, monkeypatch,
mock_add_model_folder_path, mock_add_model_folder_path,
mock_expandvars_appdata, mock_expandvars_appdata,
yaml_config_with_appdata, yaml_config_with_appdata,
mock_yaml_content_appdata mock_yaml_content_appdata
): ):
# Set the mock_file to return yaml with appdata as a variable # Set the mock_file to return yaml with appdata as a variable
mock_file.return_value.read.return_value = yaml_config_with_appdata mock_file.return_value.read.return_value = yaml_config_with_appdata
@ -115,7 +127,7 @@ def test_load_extra_model_paths_expands_appdata(
] ]
assert mock_add_model_folder_path.call_count == len(expected_calls) assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check the base path variable was expanded # Check the base path variable was expanded
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call assert actual_call.args == expected_call