mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-01 09:10:16 +08:00
Fix tests, improve distributed worker health check, add torch compile options
This commit is contained in:
parent
ffb4ed9cf2
commit
83b2f0174c
@ -1,25 +1,4 @@
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
|
||||
def load_extra_path_config(yaml_path):
|
||||
from . import folder_paths
|
||||
from ..extra_config import load_extra_path_config
|
||||
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
for c in config:
|
||||
conf = config[c]
|
||||
if conf is None:
|
||||
continue
|
||||
base_path = None
|
||||
if "base_path" in conf:
|
||||
base_path = conf.pop("base_path")
|
||||
for x in conf:
|
||||
for y in conf[x].split("\n"):
|
||||
if len(y) == 0:
|
||||
continue
|
||||
full_path = y
|
||||
if base_path is not None:
|
||||
full_path = os.path.join(base_path, full_path)
|
||||
logging.info(f"Adding extra search path {x} ({full_path})")
|
||||
folder_paths.add_model_folder_path(x, full_path)
|
||||
return load_extra_path_config(yaml_path)
|
||||
|
||||
@ -349,7 +349,7 @@ def invalidate_cache(folder_name):
|
||||
_filename_list_cache.pop(folder_name, None)
|
||||
|
||||
|
||||
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
|
||||
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
|
||||
"""
|
||||
Example:
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
|
||||
@ -41,7 +41,14 @@ class DistributedPromptWorker:
|
||||
self._health_check_site: Optional[web.TCPSite] = None
|
||||
|
||||
async def _health_check(self, request):
|
||||
return web.Response(text="OK", content_type="text/plain")
|
||||
if self._connection is None:
|
||||
return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503)
|
||||
|
||||
is_healthy = await self._is_connection_healthy()
|
||||
if is_healthy:
|
||||
return web.Response(text="HEALTHY", status=200)
|
||||
else:
|
||||
return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503)
|
||||
|
||||
async def _start_health_check_server(self):
|
||||
app = web.Application()
|
||||
@ -85,9 +92,27 @@ class DistributedPromptWorker:
|
||||
await self.on_did_complete_work_item(request_obj, reply)
|
||||
return asdict(reply)
|
||||
|
||||
async def _is_connection_healthy(self):
|
||||
if self._connection is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
not self._connection.is_closed
|
||||
and self._connection.connected.is_set()
|
||||
and await self._check_connection_ready()
|
||||
)
|
||||
|
||||
async def _check_connection_ready(self):
|
||||
try:
|
||||
await asyncio.wait_for(self._connection.ready(), timeout=1.0)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
|
||||
@tracer.start_as_current_span("Initialize Prompt Worker")
|
||||
async def init(self):
|
||||
await self._exit_stack.__aenter__()
|
||||
await self._start_health_check_server()
|
||||
try:
|
||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||
except AMQPConnectionError as connection_error:
|
||||
@ -102,7 +127,6 @@ class DistributedPromptWorker:
|
||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||
|
||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||
await self._start_health_check_server()
|
||||
|
||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||
await self.init()
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import os
|
||||
import yaml
|
||||
import folder_paths
|
||||
import logging
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_extra_path_config(yaml_path):
|
||||
from .cmd import folder_paths
|
||||
|
||||
with open(yaml_path, 'r') as stream:
|
||||
config = yaml.safe_load(stream)
|
||||
for c in config:
|
||||
@ -1,16 +1,17 @@
|
||||
import math
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from scipy import integrate
|
||||
from torch import nn
|
||||
from tqdm.auto import trange, tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
from . import utils
|
||||
from .. import model_patcher
|
||||
from .. import model_sampling
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
@ -274,6 +275,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f'Order {order} too high for step {i}')
|
||||
|
||||
def fn(tau):
|
||||
prod = 1.
|
||||
for k in range(order):
|
||||
@ -281,6 +283,7 @@ def linear_multistep_coeff(order, t, i, j):
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
|
||||
|
||||
|
||||
@ -306,6 +309,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||
|
||||
class PIDStepSizeController:
|
||||
"""A PID controller for ODE adaptive step size control."""
|
||||
|
||||
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
||||
self.h = h
|
||||
self.b1 = (pcoeff + icoeff + dcoeff) / order
|
||||
@ -552,17 +556,17 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
lambda_fn = lambda sigma: ((1 - sigma) / sigma).log()
|
||||
|
||||
# logged_x = x.unsqueeze(0)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
@ -590,10 +594,11 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
|
||||
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
@ -665,6 +670,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
"""DPM-Solver++(2M) SDE."""
|
||||
@ -713,6 +719,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
h_last = h if h is not None else h_last
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""DPM-Solver++(3M) SDE."""
|
||||
@ -766,6 +773,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
h_1, h_2 = h, h_1
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if len(sigmas) <= 1:
|
||||
@ -775,6 +783,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
|
||||
if len(sigmas) <= 1:
|
||||
@ -784,6 +793,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
if len(sigmas) <= 1:
|
||||
@ -804,6 +814,7 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
|
||||
return mu
|
||||
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -823,6 +834,7 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
|
||||
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@ -839,7 +851,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""
|
||||
@ -893,12 +904,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
|
||||
w = 2 * sigmas[0]
|
||||
w2 = sigmas[i+1]/w
|
||||
w2 = sigmas[i + 1] / w
|
||||
w1 = 1 - w2
|
||||
|
||||
d_prime = d * w1 + d_2 * w2
|
||||
|
||||
|
||||
x = x + d_prime * dt
|
||||
|
||||
else:
|
||||
@ -922,8 +932,8 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
return x
|
||||
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
# under Apache 2 license
|
||||
def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
@ -943,27 +953,28 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
order = min(max_order, i + 1)
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
elif order == 2: # Use one history point.
|
||||
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
|
||||
elif order == 3: # Use two history points.
|
||||
elif order == 3: # Use two history points.
|
||||
x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
|
||||
elif order == 4: # Use three history points.
|
||||
elif order == 4: # Use three history points.
|
||||
x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[k] = buffer_model[k + 1]
|
||||
buffer_model[-1] = d_cur
|
||||
else:
|
||||
buffer_model.append(d_cur)
|
||||
|
||||
return x_next
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
|
||||
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
# under Apache 2 license
|
||||
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
@ -984,32 +995,32 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
if order == 1: # First Euler step.
|
||||
order = min(max_order, i + 1)
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
elif order == 2: # Use one history point.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
h_n_1 = (t_cur - t_steps[i - 1])
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2
|
||||
coeff2 = -(h_n / h_n_1) / 2
|
||||
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
|
||||
elif order == 3: # Use two history points.
|
||||
elif order == 3: # Use two history points.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
h_n_2 = (t_steps[i-1] - t_steps[i-2])
|
||||
h_n_1 = (t_cur - t_steps[i - 1])
|
||||
h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
|
||||
temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
|
||||
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
|
||||
coeff3 = temp * h_n_1 / h_n_2
|
||||
x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
|
||||
elif order == 4: # Use three history points.
|
||||
elif order == 4: # Use three history points.
|
||||
h_n = (t_next - t_cur)
|
||||
h_n_1 = (t_cur - t_steps[i-1])
|
||||
h_n_2 = (t_steps[i-1] - t_steps[i-2])
|
||||
h_n_3 = (t_steps[i-2] - t_steps[i-3])
|
||||
h_n_1 = (t_cur - t_steps[i - 1])
|
||||
h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
|
||||
h_n_3 = (t_steps[i - 2] - t_steps[i - 3])
|
||||
temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
|
||||
temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
|
||||
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
|
||||
* (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
|
||||
coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
|
||||
coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
|
||||
coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
|
||||
@ -1018,15 +1029,16 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[k] = buffer_model[k + 1]
|
||||
buffer_model[-1] = d_cur.detach()
|
||||
else:
|
||||
buffer_model.append(d_cur.detach())
|
||||
|
||||
return x_next
|
||||
|
||||
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
#under Apache 2 license
|
||||
|
||||
# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
|
||||
# under Apache 2 license
|
||||
@torch.no_grad()
|
||||
def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
@ -1050,36 +1062,38 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
|
||||
d_cur = (x_cur - denoised) / t_cur
|
||||
|
||||
order = min(max_order, i+1)
|
||||
order = min(max_order, i + 1)
|
||||
if t_next <= 0:
|
||||
order = 1
|
||||
|
||||
if order == 1: # First Euler step.
|
||||
if order == 1: # First Euler step.
|
||||
x_next = x_cur + (t_next - t_cur) * d_cur
|
||||
elif order == 2: # Use one history point.
|
||||
elif order == 2: # Use one history point.
|
||||
coeff_cur, coeff_prev1 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
|
||||
elif order == 3: # Use two history points.
|
||||
elif order == 3: # Use two history points.
|
||||
coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
|
||||
elif order == 4: # Use three history points.
|
||||
elif order == 4: # Use three history points.
|
||||
coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
|
||||
x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
|
||||
|
||||
if len(buffer_model) == max_order - 1:
|
||||
for k in range(max_order - 2):
|
||||
buffer_model[k] = buffer_model[k+1]
|
||||
buffer_model[k] = buffer_model[k + 1]
|
||||
buffer_model[-1] = d_cur.detach()
|
||||
else:
|
||||
buffer_model.append(d_cur.detach())
|
||||
|
||||
return x_next
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
temp = [0]
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
@ -1099,6 +1113,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = denoised + d * sigmas[i + 1]
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
@ -1106,6 +1121,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
@ -1126,6 +1142,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
@ -1133,12 +1151,13 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
|
||||
temp = [0]
|
||||
|
||||
def post_cfg_function(args):
|
||||
temp[0] = args["uncond_denoised"]
|
||||
return args["denoised"]
|
||||
|
||||
model_options = extra_args.get("model_options", {}).copy()
|
||||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
|
||||
52
comfy_extras/nodes/nodes_torch_compile.py
Normal file
52
comfy_extras/nodes/nodes_torch_compile.py
Normal file
@ -0,0 +1,52 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
DIFFUSION_MODEL = "diffusion_model"
|
||||
|
||||
|
||||
class TorchCompileModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
},
|
||||
"optional": {
|
||||
"object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
|
||||
"fullgraph": ("BOOLEAN", {"default": False}),
|
||||
"dynamic": ("BOOLEAN", {"default": False}),
|
||||
"backend": ("STRING", {"default": "inductor"}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
|
||||
if object_patch is None:
|
||||
object_patch = DIFFUSION_MODEL
|
||||
compile_kwargs = {
|
||||
"fullgraph": fullgraph,
|
||||
"dynamic": dynamic,
|
||||
"backend": backend
|
||||
}
|
||||
if isinstance(model, ModelPatcher):
|
||||
m = model.clone()
|
||||
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
|
||||
return (m,)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return torch.compile(model=model, **compile_kwargs),
|
||||
else:
|
||||
logging.warning("Encountered a model that cannot be compiled")
|
||||
return model,
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TorchCompileModel": TorchCompileModel,
|
||||
}
|
||||
@ -1,21 +0,0 @@
|
||||
import torch
|
||||
|
||||
class TorchCompileModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TorchCompileModel": TorchCompileModel,
|
||||
}
|
||||
@ -118,7 +118,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
async with session.get(url, timeout=1) as response:
|
||||
if response.status == 200 and await response.text() == "OK":
|
||||
if response.status == 200:
|
||||
return True
|
||||
except Exception as exc_info:
|
||||
pass
|
||||
|
||||
@ -59,7 +59,7 @@ class _ProgressHandler(ServerStub):
|
||||
self.tuples.append((event, data, sid))
|
||||
|
||||
|
||||
class Client:
|
||||
class ComfyClient:
|
||||
def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler):
|
||||
self.embedded_client = embedded_client
|
||||
self.progress_handler = progress_handler
|
||||
@ -105,7 +105,7 @@ class TestExecution:
|
||||
(0,),
|
||||
(100,),
|
||||
])
|
||||
async def client(self, request) -> Client:
|
||||
async def client(self, request) -> ComfyClient:
|
||||
from comfy.cmd.execution import nodes
|
||||
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
@ -115,13 +115,13 @@ class TestExecution:
|
||||
configuration.cache_lru = lru_size
|
||||
progress_handler = _ProgressHandler()
|
||||
async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
|
||||
yield Client(embedded_client, progress_handler)
|
||||
yield ComfyClient(embedded_client, progress_handler)
|
||||
|
||||
@fixture
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
async def test_lazy_input(self, client: Client, builder: GraphBuilder):
|
||||
async def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -138,7 +138,7 @@ class TestExecution:
|
||||
assert result.did_run(mask)
|
||||
assert result.did_run(lazy_mix)
|
||||
|
||||
async def test_full_cache(self, client: Client, builder: GraphBuilder):
|
||||
async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -152,7 +152,7 @@ class TestExecution:
|
||||
for node_id, node in g.nodes.items():
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
|
||||
async def test_partial_cache(self, client: Client, builder: GraphBuilder):
|
||||
async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -167,7 +167,7 @@ class TestExecution:
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
|
||||
async def test_error(self, client: Client, builder: GraphBuilder):
|
||||
async def test_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
# Different size of the two images
|
||||
@ -188,7 +188,7 @@ class TestExecution:
|
||||
("foo", True),
|
||||
(5.0, False),
|
||||
])
|
||||
async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder):
|
||||
async def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
|
||||
g.node("SaveImage", images=validation1.out(0))
|
||||
@ -203,7 +203,7 @@ class TestExecution:
|
||||
("StubInt", 5),
|
||||
("StubFloat", 5.0)
|
||||
])
|
||||
async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder):
|
||||
async def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
stub = g.node(test_type, value=test_value)
|
||||
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
|
||||
@ -216,7 +216,7 @@ class TestExecution:
|
||||
("StubInt", 5, True),
|
||||
("StubFloat", 5.0, False)
|
||||
])
|
||||
async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
|
||||
async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
stub = g.node(test_type, value=test_value)
|
||||
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
|
||||
@ -232,7 +232,7 @@ class TestExecution:
|
||||
("StubInt", 5, True),
|
||||
("StubFloat", 5.0, False)
|
||||
])
|
||||
async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
|
||||
async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
stub = g.node(test_type, value=test_value)
|
||||
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
|
||||
@ -248,7 +248,7 @@ class TestExecution:
|
||||
("StubInt", 5, True),
|
||||
("StubFloat", 5.0, False)
|
||||
])
|
||||
async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
|
||||
async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
stub = g.node(test_type, value=test_value)
|
||||
validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
|
||||
@ -260,7 +260,7 @@ class TestExecution:
|
||||
else:
|
||||
await client.run(g)
|
||||
|
||||
async def test_cycle_error(self, client: Client, builder: GraphBuilder):
|
||||
async def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -274,7 +274,7 @@ class TestExecution:
|
||||
with pytest.raises(ValueError):
|
||||
await client.run(g)
|
||||
|
||||
async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder):
|
||||
async def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -289,7 +289,7 @@ class TestExecution:
|
||||
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
|
||||
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
|
||||
|
||||
async def test_custom_is_changed(self, client: Client, builder: GraphBuilder):
|
||||
async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
# Creating the nodes in this specific order previously caused a bug
|
||||
save = g.node("SaveImage")
|
||||
@ -309,7 +309,7 @@ class TestExecution:
|
||||
assert result3.did_run(is_changed), "is_changed should have been re-run"
|
||||
assert result4.did_run(is_changed), "is_changed should not have been cached"
|
||||
|
||||
async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder):
|
||||
async def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -323,7 +323,7 @@ class TestExecution:
|
||||
expected = 255 // 4
|
||||
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
|
||||
|
||||
async def test_for_loop(self, client: Client, builder: GraphBuilder):
|
||||
async def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
iterations = 4
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
@ -342,7 +342,7 @@ class TestExecution:
|
||||
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
|
||||
assert result.did_run(is_changed)
|
||||
|
||||
async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder):
|
||||
async def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
|
||||
mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
|
||||
@ -361,7 +361,7 @@ class TestExecution:
|
||||
for i in range(3):
|
||||
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
|
||||
|
||||
async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder):
|
||||
async def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
|
||||
mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
|
||||
@ -378,7 +378,7 @@ class TestExecution:
|
||||
assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
|
||||
assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
|
||||
|
||||
async def test_missing_node_error(self, client: Client, builder: GraphBuilder):
|
||||
async def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -397,7 +397,7 @@ class TestExecution:
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
await client.run(g)
|
||||
|
||||
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
|
||||
async def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
@ -410,9 +410,8 @@ class TestExecution:
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder):
|
||||
async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||
@ -433,7 +432,7 @@ class TestExecution:
|
||||
# This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
|
||||
# as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
|
||||
# only that one entry in the list is blocked.
|
||||
def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||
async def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -445,11 +444,11 @@ class TestExecution:
|
||||
int_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0), value3=int3.out(0))
|
||||
compare = g.node("TestIntConditions", a=int_list.out(0), b=2, operation="==")
|
||||
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)
|
||||
|
||||
|
||||
list_output = g.node("TestMakeListNode", value1=blocker.out(0))
|
||||
output = g.node("PreviewImage", images=list_output.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
result = await client.run(g)
|
||||
assert result.did_run(output), "The execution should have run"
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
### 🗻 This file is created through the spirit of Mount Fuji at its peak
|
||||
# TODO(yoland): clean up this after I get back down
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import folder_paths
|
||||
import pytest
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
@ -19,21 +21,25 @@ def test_get_directory_by_type():
|
||||
assert folder_paths.get_directory_by_type("output") == test_dir
|
||||
assert folder_paths.get_directory_by_type("invalid") is None
|
||||
|
||||
|
||||
def test_annotated_filepath():
|
||||
assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
|
||||
assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
|
||||
assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
|
||||
assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())
|
||||
|
||||
|
||||
def test_get_annotated_filepath():
|
||||
default_dir = "/default/dir"
|
||||
assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
|
||||
assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")
|
||||
|
||||
|
||||
def test_add_model_folder_path():
|
||||
folder_paths.add_model_folder_path("test_folder", "/test/path")
|
||||
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||
|
||||
|
||||
def test_recursive_search(temp_dir):
|
||||
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||
@ -43,12 +49,14 @@ def test_recursive_search(temp_dir):
|
||||
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||
assert len(dirs) == 2 # temp_dir and subdir
|
||||
|
||||
|
||||
def test_filter_files_extensions():
|
||||
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||
assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
|
||||
assert folder_paths.filter_files_extensions(files, []) == files
|
||||
|
||||
|
||||
@patch("folder_paths.recursive_search")
|
||||
@patch("folder_paths.folder_names_and_paths")
|
||||
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||
@ -56,6 +64,7 @@ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||
|
||||
|
||||
def test_get_save_image_path(temp_dir):
|
||||
with patch("folder_paths.output_directory", temp_dir):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||
@ -63,4 +72,4 @@ def test_get_save_image_path(temp_dir):
|
||||
assert filename == "test"
|
||||
assert counter == 1
|
||||
assert subfolder == ""
|
||||
assert filename_prefix == "test"
|
||||
assert filename_prefix == "test"
|
||||
@ -1,13 +1,16 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from folder_paths import filter_files_content_types
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def file_extensions():
|
||||
return {
|
||||
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
|
||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
|
||||
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
|
||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
|
||||
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
||||
}
|
||||
|
||||
@ -49,4 +52,4 @@ def test_handles_no_extension():
|
||||
|
||||
def test_handles_no_files():
|
||||
files = []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
0
tests/unit/utils/__init__.py
Normal file
0
tests/unit/utils/__init__.py
Normal file
@ -1,10 +1,12 @@
|
||||
import pytest
|
||||
import yaml
|
||||
import os
|
||||
from unittest.mock import Mock, patch, mock_open
|
||||
|
||||
from utils.extra_config import load_extra_path_config
|
||||
import folder_paths
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.extra_config import load_extra_path_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content():
|
||||
@ -15,10 +17,12 @@ def mock_yaml_content():
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanded_home():
|
||||
return '/home/user'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yaml_config_with_appdata():
|
||||
return """
|
||||
@ -27,40 +31,47 @@ def yaml_config_with_appdata():
|
||||
checkpoints: 'models/checkpoints'
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_content_appdata(yaml_config_with_appdata):
|
||||
return yaml.safe_load(yaml_config_with_appdata)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expandvars_appdata():
|
||||
mock = Mock()
|
||||
mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_add_model_folder_path():
|
||||
return Mock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_expanduser(mock_expanded_home):
|
||||
def _expanduser(path):
|
||||
if path.startswith('~/'):
|
||||
return os.path.join(mock_expanded_home, path[2:])
|
||||
return path
|
||||
|
||||
return _expanduser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_yaml_safe_load(mock_yaml_content):
|
||||
return Mock(return_value=mock_yaml_content)
|
||||
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
|
||||
def test_load_extra_model_paths_expands_userpath(
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expanduser,
|
||||
mock_yaml_safe_load,
|
||||
mock_expanded_home
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expanduser,
|
||||
mock_yaml_safe_load,
|
||||
mock_expanded_home
|
||||
):
|
||||
# Attach mocks used by load_extra_path_config
|
||||
monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
|
||||
@ -75,7 +86,7 @@ def test_load_extra_model_paths_expands_userpath(
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
|
||||
# Check if add_model_folder_path was called with the correct arguments
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args == expected_call
|
||||
@ -86,14 +97,15 @@ def test_load_extra_model_paths_expands_userpath(
|
||||
# Check if open was called with the correct file path
|
||||
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
|
||||
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
def test_load_extra_model_paths_expands_appdata(
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expandvars_appdata,
|
||||
yaml_config_with_appdata,
|
||||
mock_yaml_content_appdata
|
||||
mock_file,
|
||||
monkeypatch,
|
||||
mock_add_model_folder_path,
|
||||
mock_expandvars_appdata,
|
||||
yaml_config_with_appdata,
|
||||
mock_yaml_content_appdata
|
||||
):
|
||||
# Set the mock_file to return yaml with appdata as a variable
|
||||
mock_file.return_value.read.return_value = yaml_config_with_appdata
|
||||
@ -115,7 +127,7 @@ def test_load_extra_model_paths_expands_appdata(
|
||||
]
|
||||
|
||||
assert mock_add_model_folder_path.call_count == len(expected_calls)
|
||||
|
||||
|
||||
# Check the base path variable was expanded
|
||||
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
|
||||
assert actual_call.args == expected_call
|
||||
Loading…
Reference in New Issue
Block a user