Fix tests, improve distributed worker health check, add torch compile options
parent ffb4ed9cf2
commit 83b2f0174c
@@ -1,25 +1,4 @@
-import os
-import yaml
-import logging
-
 def load_extra_path_config(yaml_path):
-    from . import folder_paths
+    from ..extra_config import load_extra_path_config
 
-    with open(yaml_path, 'r') as stream:
-        config = yaml.safe_load(stream)
-    for c in config:
-        conf = config[c]
-        if conf is None:
-            continue
-        base_path = None
-        if "base_path" in conf:
-            base_path = conf.pop("base_path")
-        for x in conf:
-            for y in conf[x].split("\n"):
-                if len(y) == 0:
-                    continue
-                full_path = y
-                if base_path is not None:
-                    full_path = os.path.join(base_path, full_path)
-                logging.info(f"Adding extra search path {x} ({full_path})")
-                folder_paths.add_model_folder_path(x, full_path)
+    return load_extra_path_config(yaml_path)
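
The old loader body is removed entirely; what remains is a thin forwarding shim, so existing call sites keep working while the implementation moves into extra_config. A minimal sketch of the same pattern from a caller's side (the legacy import path below is hypothetical):

    # Hypothetical legacy call site; it keeps working because the old
    # module now delegates to the relocated implementation.
    from utils.extra_config import load_extra_path_config

    load_extra_path_config("extra_model_paths.yaml")
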
@@ -349,7 +349,7 @@ def invalidate_cache(folder_name):
     _filename_list_cache.pop(folder_name, None)
 
 
-def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
+def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
     """
     Example:
         files = os.listdir(folder_paths.get_input_directory())
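
The corrected annotation makes content_types a list, so one call can filter for several media kinds at once. A usage sketch under that signature (file names illustrative; the real function matches on detected content type, not just extension):

    from comfy.cmd import folder_paths

    files = ["photo.png", "clip.mp4", "song.mp3", "notes.txt"]
    # Keeps only files whose detected content type is image or audio.
    media = folder_paths.filter_files_content_types(files, ["image", "audio"])
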
@@ -41,7 +41,14 @@ class DistributedPromptWorker:
         self._health_check_site: Optional[web.TCPSite] = None
 
     async def _health_check(self, request):
-        return web.Response(text="OK", content_type="text/plain")
+        if self._connection is None:
+            return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503)
+
+        is_healthy = await self._is_connection_healthy()
+        if is_healthy:
+            return web.Response(text="HEALTHY", status=200)
+        else:
+            return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503)
 
     async def _start_health_check_server(self):
         app = web.Application()
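
The endpoint now distinguishes a missing connection from an unhealthy one and returns 503 in both failure cases, which is the status code load balancers and container orchestrators key on. A minimal external probe, with the host, port, and path assumed for illustration since the actual binding comes from _start_health_check_server:

    import asyncio

    import aiohttp

    async def probe(url: str = "http://localhost:9090/health") -> bool:
        # url is an assumption for illustration; any non-200 status
        # (including the 503 UNHEALTHY responses above) reads as unhealthy.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=2)) as response:
                    return response.status == 200
        except aiohttp.ClientError:
            return False

    print(asyncio.run(probe()))
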
@@ -85,9 +92,27 @@ class DistributedPromptWorker:
         await self.on_did_complete_work_item(request_obj, reply)
         return asdict(reply)
 
+    async def _is_connection_healthy(self):
+        if self._connection is None:
+            return False
+
+        return (
+                not self._connection.is_closed
+                and self._connection.connected.is_set()
+                and await self._check_connection_ready()
+        )
+
+    async def _check_connection_ready(self):
+        try:
+            await asyncio.wait_for(self._connection.ready(), timeout=1.0)
+            return True
+        except asyncio.TimeoutError:
+            return False
+
     @tracer.start_as_current_span("Initialize Prompt Worker")
     async def init(self):
         await self._exit_stack.__aenter__()
+        await self._start_health_check_server()
         try:
             self._connection = await connect_robust(self._connection_uri, loop=self._loop)
         except AMQPConnectionError as connection_error:
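
_check_connection_ready treats a slow readiness check exactly like a failed one: if the awaited check does not resolve within the timeout, the worker reports unhealthy rather than hanging the probe. The same pattern in isolation:

    import asyncio

    async def completes_within(awaitable, timeout: float = 1.0) -> bool:
        # Bound the wait; a timeout is reported as failure, not raised.
        try:
            await asyncio.wait_for(awaitable, timeout=timeout)
            return True
        except asyncio.TimeoutError:
            return False

    async def main():
        print(await completes_within(asyncio.sleep(0.1)))  # True
        print(await completes_within(asyncio.sleep(5.0)))  # False after ~1s

    asyncio.run(main())
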
@@ -102,7 +127,6 @@ class DistributedPromptWorker:
         await self._exit_stack.enter_async_context(self._embedded_comfy_client)
 
         await self._rpc.register(self._queue_name, self._do_work_item)
-        await self._start_health_check_server()
 
     async def __aenter__(self) -> "DistributedPromptWorker":
         await self.init()
@@ -1,9 +1,12 @@
-import os
-import yaml
-import folder_paths
 import logging
+import os
+
+import yaml
+
 
 def load_extra_path_config(yaml_path):
+    from .cmd import folder_paths
+
     with open(yaml_path, 'r') as stream:
         config = yaml.safe_load(stream)
     for c in config:
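
Moving the folder_paths import into the function body defers it to call time, the usual way to break a circular dependency between a config module and a package that imports it during startup. Shape of the pattern (the function name and paths here are illustrative):

    def register_paths(config_path):
        # Imported here, not at module top level, so importing this module
        # never triggers the comfy.cmd import chain during startup.
        from .cmd import folder_paths
        folder_paths.add_model_folder_path("checkpoints", "/models/checkpoints")
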
@@ -1,16 +1,17 @@
 import math
 
-from scipy import integrate
 import torch
-from torch import nn
 import torchsde
+from scipy import integrate
+from torch import nn
 from tqdm.auto import trange, tqdm
 
-from . import utils
 from . import deis
+from . import utils
 from .. import model_patcher
 from .. import model_sampling
 
+
 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
 
@@ -274,6 +275,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
 def linear_multistep_coeff(order, t, i, j):
     if order - 1 > i:
         raise ValueError(f'Order {order} too high for step {i}')
+
     def fn(tau):
         prod = 1.
         for k in range(order):
@@ -281,6 +283,7 @@ def linear_multistep_coeff(order, t, i, j):
                 continue
             prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
         return prod
 
     return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
 
+
@@ -306,6 +309,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
 
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
+
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
         self.h = h
         self.b1 = (pcoeff + icoeff + dcoeff) / order
@@ -552,17 +556,17 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
-    lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
+    lambda_fn = lambda sigma: ((1 - sigma) / sigma).log()
 
     # logged_x = x.unsqueeze(0)
 
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
-        downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
-        sigma_down = sigmas[i+1] * downstep_ratio
-        alpha_ip1 = 1 - sigmas[i+1]
+        downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
+        sigma_down = sigmas[i + 1] * downstep_ratio
+        alpha_ip1 = 1 - sigmas[i + 1]
         alpha_down = 1 - sigma_down
-        renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
+        renoise_coeff = (sigmas[i + 1] ** 2 - sigma_down ** 2 * alpha_ip1 ** 2 / alpha_down ** 2) ** 0.5
         # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@@ -590,10 +594,11 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
         # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
         # Noise addition
         if sigmas[i + 1] > 0 and eta > 0:
-            x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
+            x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
     # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
@@ -665,6 +670,7 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
         old_denoised = denoised
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     """DPM-Solver++(2M) SDE."""
@@ -713,6 +719,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
         h_last = h if h is not None else h_last
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """DPM-Solver++(3M) SDE."""
@@ -766,6 +773,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
         h_1, h_2 = h, h_1
     return x
 
+
 @torch.no_grad()
 def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     if len(sigmas) <= 1:
@@ -775,6 +783,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
 
+
 @torch.no_grad()
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
@@ -784,6 +793,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
     return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
 
+
 @torch.no_grad()
 def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     if len(sigmas) <= 1:
@@ -804,6 +814,7 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
         mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
     return mu
 
+
 def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -823,6 +834,7 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
 def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
     return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
 
+
 @torch.no_grad()
 def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
     extra_args = {} if extra_args is None else extra_args
@@ -839,7 +851,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
     return x
 
 
-
 @torch.no_grad()
 def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     """
@@ -893,12 +904,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
 
             w = 2 * sigmas[0]
-            w2 = sigmas[i+1]/w
+            w2 = sigmas[i + 1] / w
             w1 = 1 - w2
 
             d_prime = d * w1 + d_2 * w2
 
-
             x = x + d_prime * dt
 
         else:
@@ -922,8 +932,8 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
     return x
 
 
-#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
-#under Apache 2 license
+# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
+# under Apache 2 license
 def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
@@ -943,27 +953,28 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
 
         d_cur = (x_cur - denoised) / t_cur
 
-        order = min(max_order, i+1)
+        order = min(max_order, i + 1)
         if order == 1:  # First Euler step.
             x_next = x_cur + (t_next - t_cur) * d_cur
         elif order == 2:  # Use one history point.
             x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
         elif order == 3:  # Use two history points.
             x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
         elif order == 4:  # Use three history points.
             x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
 
         if len(buffer_model) == max_order - 1:
             for k in range(max_order - 2):
-                buffer_model[k] = buffer_model[k+1]
+                buffer_model[k] = buffer_model[k + 1]
             buffer_model[-1] = d_cur
         else:
             buffer_model.append(d_cur)
 
     return x_next
 
-#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
-#under Apache 2 license
+
+# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
+# under Apache 2 license
 def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
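
The hard-coded weights in sample_ipndm (3/2 and -1/2; 23/12, -16/12, 5/12; 55/24, -59/24, 37/24, -9/24) are the classical Adams-Bashforth multistep coefficients. A quick self-contained check, not part of the library, that recovers them by integrating the Lagrange basis over one unit step:

    import numpy as np
    from scipy import integrate

    def ab_coefficients(order):
        # History points at 0, -1, ..., -(order - 1), stepping to +1.
        t = np.arange(0, -order, -1)
        coeffs = []
        for j in range(order):
            fn = lambda tau, j=j: np.prod(
                [(tau - t[k]) / (t[j] - t[k]) for k in range(order) if k != j])
            coeffs.append(integrate.quad(fn, 0, 1)[0])
        return coeffs

    print(ab_coefficients(2))  # [1.5, -0.5]
    print(ab_coefficients(3))  # ~[23/12, -16/12, 5/12]
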
@@ -984,32 +995,32 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
 
         d_cur = (x_cur - denoised) / t_cur
 
-        order = min(max_order, i+1)
+        order = min(max_order, i + 1)
         if order == 1:  # First Euler step.
             x_next = x_cur + (t_next - t_cur) * d_cur
         elif order == 2:  # Use one history point.
             h_n = (t_next - t_cur)
-            h_n_1 = (t_cur - t_steps[i-1])
+            h_n_1 = (t_cur - t_steps[i - 1])
             coeff1 = (2 + (h_n / h_n_1)) / 2
             coeff2 = -(h_n / h_n_1) / 2
             x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
         elif order == 3:  # Use two history points.
             h_n = (t_next - t_cur)
-            h_n_1 = (t_cur - t_steps[i-1])
-            h_n_2 = (t_steps[i-1] - t_steps[i-2])
+            h_n_1 = (t_cur - t_steps[i - 1])
+            h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
             temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
             coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
             coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
             coeff3 = temp * h_n_1 / h_n_2
             x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
         elif order == 4:  # Use three history points.
             h_n = (t_next - t_cur)
-            h_n_1 = (t_cur - t_steps[i-1])
-            h_n_2 = (t_steps[i-1] - t_steps[i-2])
-            h_n_3 = (t_steps[i-2] - t_steps[i-3])
+            h_n_1 = (t_cur - t_steps[i - 1])
+            h_n_2 = (t_steps[i - 1] - t_steps[i - 2])
+            h_n_3 = (t_steps[i - 2] - t_steps[i - 3])
             temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
             temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
                 * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
             coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
             coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
             coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
@@ -1018,15 +1029,16 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
 
         if len(buffer_model) == max_order - 1:
             for k in range(max_order - 2):
-                buffer_model[k] = buffer_model[k+1]
+                buffer_model[k] = buffer_model[k + 1]
             buffer_model[-1] = d_cur.detach()
         else:
             buffer_model.append(d_cur.detach())
 
     return x_next
 
-#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
-#under Apache 2 license
+
+# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
+# under Apache 2 license
 @torch.no_grad()
 def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
     extra_args = {} if extra_args is None else extra_args
@@ -1050,36 +1062,38 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
 
         d_cur = (x_cur - denoised) / t_cur
 
-        order = min(max_order, i+1)
+        order = min(max_order, i + 1)
         if t_next <= 0:
             order = 1
+
         if order == 1:  # First Euler step.
             x_next = x_cur + (t_next - t_cur) * d_cur
         elif order == 2:  # Use one history point.
             coeff_cur, coeff_prev1 = coeff_list[i]
             x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
         elif order == 3:  # Use two history points.
             coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
             x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
         elif order == 4:  # Use three history points.
             coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
             x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
 
         if len(buffer_model) == max_order - 1:
             for k in range(max_order - 2):
-                buffer_model[k] = buffer_model[k+1]
+                buffer_model[k] = buffer_model[k + 1]
             buffer_model[-1] = d_cur.detach()
         else:
             buffer_model.append(d_cur.detach())
 
     return x_next
 
+
 @torch.no_grad()
 def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
     extra_args = {} if extra_args is None else extra_args
 
     temp = [0]
 
     def post_cfg_function(args):
         temp[0] = args["uncond_denoised"]
         return args["denoised"]
@@ -1099,6 +1113,7 @@ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disabl
         x = denoised + d * sigmas[i + 1]
     return x
 
+
 @torch.no_grad()
 def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with Euler method steps."""
@@ -1106,6 +1121,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+
     temp = [0]
 
     def post_cfg_function(args):
         temp[0] = args["uncond_denoised"]
         return args["denoised"]
@@ -1126,6 +1142,8 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
         if sigmas[i + 1] > 0:
             x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
+
+
 @torch.no_grad()
 def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
@@ -1133,12 +1151,13 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+
     temp = [0]
 
     def post_cfg_function(args):
         temp[0] = args["uncond_denoised"]
         return args["denoised"]
 
     model_options = extra_args.get("model_options", {}).copy()
-    extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+    extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
 
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
comfy_extras/nodes/nodes_torch_compile.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+import logging
+
+import torch
+
+from comfy.model_patcher import ModelPatcher
+
+DIFFUSION_MODEL = "diffusion_model"
+
+
+class TorchCompileModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL",),
+            },
+            "optional": {
+                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
+                "fullgraph": ("BOOLEAN", {"default": False}),
+                "dynamic": ("BOOLEAN", {"default": False}),
+                "backend": ("STRING", {"default": "inductor"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+    EXPERIMENTAL = True
+
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
+        if object_patch is None:
+            object_patch = DIFFUSION_MODEL
+        compile_kwargs = {
+            "fullgraph": fullgraph,
+            "dynamic": dynamic,
+            "backend": backend
+        }
+        if isinstance(model, ModelPatcher):
+            m = model.clone()
+            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
+            return (m,)
+        elif isinstance(model, torch.nn.Module):
+            return torch.compile(model=model, **compile_kwargs),
+        else:
+            logging.warning("Encountered a model that cannot be compiled")
+            return model,
+
+
+NODE_CLASS_MAPPINGS = {
+    "TorchCompileModel": TorchCompileModel,
+}
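
The new node simply threads fullgraph, dynamic, and backend through to torch.compile, either wrapping the patched sub-object of a ModelPatcher or compiling a bare nn.Module directly. A standalone sketch of the same call on a toy module:

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) * 2

    # Same keywords the node forwards; "inductor" is torch.compile's default backend.
    compiled = torch.compile(model=TinyModel(), fullgraph=False,
                             dynamic=False, backend="inductor")
    print(compiled(torch.randn(4)))
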
@@ -1,21 +0,0 @@
-import torch
-
-class TorchCompileModel:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "_for_testing"
-    EXPERIMENTAL = True
-
-    def patch(self, model):
-        m = model.clone()
-        m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
-        return (m, )
-
-NODE_CLASS_MAPPINGS = {
-    "TorchCompileModel": TorchCompileModel,
-}
@@ -118,7 +118,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
     for _ in range(max_retries):
         try:
             async with session.get(url, timeout=1) as response:
-                if response.status == 200 and await response.text() == "OK":
+                if response.status == 200:
                     return True
         except Exception as exc_info:
             pass
@@ -59,7 +59,7 @@ class _ProgressHandler(ServerStub):
         self.tuples.append((event, data, sid))
 
 
-class Client:
+class ComfyClient:
     def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler):
         self.embedded_client = embedded_client
         self.progress_handler = progress_handler
@@ -105,7 +105,7 @@ class TestExecution:
         (0,),
         (100,),
     ])
-    async def client(self, request) -> Client:
+    async def client(self, request) -> ComfyClient:
         from comfy.cmd.execution import nodes
         from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
 
@@ -115,13 +115,13 @@ class TestExecution:
         configuration.cache_lru = lru_size
         progress_handler = _ProgressHandler()
         async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
-            yield Client(embedded_client, progress_handler)
+            yield ComfyClient(embedded_client, progress_handler)
 
     @fixture
     def builder(self, request):
         yield GraphBuilder(prefix=request.node.name)
 
-    async def test_lazy_input(self, client: Client, builder: GraphBuilder):
+    async def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -138,7 +138,7 @@ class TestExecution:
         assert result.did_run(mask)
         assert result.did_run(lazy_mix)
 
-    async def test_full_cache(self, client: Client, builder: GraphBuilder):
+    async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -152,7 +152,7 @@ class TestExecution:
         for node_id, node in g.nodes.items():
             assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
 
-    async def test_partial_cache(self, client: Client, builder: GraphBuilder):
+    async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -167,7 +167,7 @@ class TestExecution:
         assert not result2.did_run(input1), "Input1 should have been cached"
         assert not result2.did_run(input2), "Input2 should have been cached"
 
-    async def test_error(self, client: Client, builder: GraphBuilder):
+    async def test_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         # Different size of the two images
@@ -188,7 +188,7 @@ class TestExecution:
         ("foo", True),
         (5.0, False),
     ])
-    async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder):
+    async def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
         g = builder
         validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
         g.node("SaveImage", images=validation1.out(0))
@@ -203,7 +203,7 @@ class TestExecution:
         ("StubInt", 5),
         ("StubFloat", 5.0)
     ])
-    async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder):
+    async def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
         g = builder
         stub = g.node(test_type, value=test_value)
         validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
@@ -216,7 +216,7 @@ class TestExecution:
         ("StubInt", 5, True),
         ("StubFloat", 5.0, False)
     ])
-    async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
+    async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
         g = builder
         stub = g.node(test_type, value=test_value)
         validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
@@ -232,7 +232,7 @@ class TestExecution:
         ("StubInt", 5, True),
         ("StubFloat", 5.0, False)
     ])
-    async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
+    async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
         g = builder
         stub = g.node(test_type, value=test_value)
         validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
@@ -248,7 +248,7 @@ class TestExecution:
         ("StubInt", 5, True),
         ("StubFloat", 5.0, False)
     ])
-    async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
+    async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
         g = builder
         stub = g.node(test_type, value=test_value)
         validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
@@ -260,7 +260,7 @@ class TestExecution:
         else:
             await client.run(g)
 
-    async def test_cycle_error(self, client: Client, builder: GraphBuilder):
+    async def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -274,7 +274,7 @@ class TestExecution:
         with pytest.raises(ValueError):
             await client.run(g)
 
-    async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder):
+    async def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -289,7 +289,7 @@ class TestExecution:
         assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
         assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
 
-    async def test_custom_is_changed(self, client: Client, builder: GraphBuilder):
+    async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
         save = g.node("SaveImage")
@@ -309,7 +309,7 @@ class TestExecution:
         assert result3.did_run(is_changed), "is_changed should have been re-run"
         assert result4.did_run(is_changed), "is_changed should not have been cached"
 
-    async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder):
+    async def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -323,7 +323,7 @@ class TestExecution:
         expected = 255 // 4
         assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
 
-    async def test_for_loop(self, client: Client, builder: GraphBuilder):
+    async def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         iterations = 4
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
@@ -342,7 +342,7 @@ class TestExecution:
         assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
         assert result.did_run(is_changed)
 
-    async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder):
+    async def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
         mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
@@ -361,7 +361,7 @@ class TestExecution:
         for i in range(3):
             assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
 
-    async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder):
+    async def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
         mask = g.node("StubMask", value=val_list.out(0), height=512, width=512, batch_size=1)
@@ -378,7 +378,7 @@ class TestExecution:
         assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
         assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
 
-    async def test_missing_node_error(self, client: Client, builder: GraphBuilder):
+    async def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
@@ -397,7 +397,7 @@ class TestExecution:
         input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
         await client.run(g)
 
-    async def test_output_reuse(self, client: Client, builder: GraphBuilder):
+    async def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
 
@@ -410,9 +410,8 @@ class TestExecution:
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"
 
-
     # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder):
+    async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
         test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
@@ -433,7 +432,7 @@ class TestExecution:
     # This tests that nodes with OUTPUT_IS_LIST function correctly when they receive an ExecutionBlocker
    # as input. We also test that when that list (containing an ExecutionBlocker) is passed to a node,
     # only that one entry in the list is blocked.
-    def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
+    async def test_execution_block_list_output(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -449,7 +448,7 @@ class TestExecution:
         list_output = g.node("TestMakeListNode", value1=blocker.out(0))
         output = g.node("PreviewImage", images=list_output.out(0))
 
-        result = client.run(g)
+        result = await client.run(g)
         assert result.did_run(output), "The execution should have run"
         images = result.get_images(output)
         assert len(images) == 2, "Should have 2 images"
@@ -1,11 +1,13 @@
 ### 🗻 This file is created through the spirit of Mount Fuji at its peak
 # TODO(yoland): clean up this after I get back down
-import pytest
 import os
 import tempfile
 from unittest.mock import patch
 
-import folder_paths
+import pytest
+
+from comfy.cmd import folder_paths
+
 
 @pytest.fixture
 def temp_dir():
@@ -19,21 +21,25 @@ def test_get_directory_by_type():
     assert folder_paths.get_directory_by_type("output") == test_dir
     assert folder_paths.get_directory_by_type("invalid") is None


 def test_annotated_filepath():
     assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None)
     assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory())
     assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory())
     assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory())


 def test_get_annotated_filepath():
     default_dir = "/default/dir"
     assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt")
     assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt")


 def test_add_model_folder_path():
     folder_paths.add_model_folder_path("test_folder", "/test/path")
     assert "/test/path" in folder_paths.get_folder_paths("test_folder")


 def test_recursive_search(temp_dir):
     os.makedirs(os.path.join(temp_dir, "subdir"))
     open(os.path.join(temp_dir, "file1.txt"), "w").close()
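The `annotated_filepath` assertions above exercise a small naming convention: a trailing ` [output]`, ` [input]`, or ` [temp]` tag selects a base directory for the file. A minimal sketch of that parsing, grounded only in the assertions; the directory values are hypothetical placeholders, and the real module resolves them from its configured directories:

```python
# Minimal sketch of the "name [annotation]" parsing the tests above exercise.
# The DIRS values are hypothetical; the real module uses its own directories.
import os

DIRS = {"output": "/data/output", "input": "/data/input", "temp": "/data/temp"}


def annotated_filepath(name: str):
    """Split 'file.txt [output]' into ('file.txt', base_dir); no tag -> None."""
    for tag, base in DIRS.items():
        suffix = f" [{tag}]"
        if name.endswith(suffix):
            return name[: -len(suffix)], base
    return name, None


assert annotated_filepath("test.txt") == ("test.txt", None)
assert annotated_filepath("test.txt [output]") == ("test.txt", "/data/output")
```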
@@ -43,12 +49,14 @@ def test_recursive_search(temp_dir):
     assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
     assert len(dirs) == 2  # temp_dir and subdir


 def test_filter_files_extensions():
     files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
     assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
     assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"]
     assert folder_paths.filter_files_extensions(files, []) == files


 @patch("folder_paths.recursive_search")
 @patch("folder_paths.folder_names_and_paths")
 def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
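The two patches above suggest that `get_filename_list` roughly composes `recursive_search` with a per-folder extension whitelist kept in `folder_names_and_paths`. A sketch of that flow inferred from the mocks, not taken from the real implementation; the data shapes are assumptions:

```python
# Flow inferred from the mocks: folder_names_and_paths maps a folder name to
# (search paths, allowed extensions); the filename list is the recursive
# search result filtered by those extensions. Shapes are hypothetical.
import os

folder_names_and_paths = {"test_folder": (["/models/test_folder"], {".txt"})}


def recursive_search(path):
    # Stand-in for the real directory walk; returns (files, dir_mtimes).
    return (["file1.txt", "file2.jpg"], {})


def get_filename_list(folder_name):
    paths, extensions = folder_names_and_paths[folder_name]
    files = []
    for path in paths:
        found, _ = recursive_search(path)
        files.extend(f for f in found if os.path.splitext(f)[1] in extensions)
    return sorted(files)


assert get_filename_list("test_folder") == ["file1.txt"]
```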
@@ -56,6 +64,7 @@ def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
     mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
     assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]


 def test_get_save_image_path(temp_dir):
     with patch("folder_paths.output_directory", temp_dir):
         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
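`get_save_image_path` returns a five-tuple, unpacked above. A hedged usage sketch of how those pieces typically combine into a numbered output filename; only the tuple's shape comes from the test, and the zero-padded counter format is an assumption for illustration:

```python
# Hedged usage sketch for the five-tuple unpacked in the test above.
# The counter format is an illustrative assumption, not the library's rule.
import os

full_output_folder = "/tmp/output"  # where the image should be written
filename = "test"                   # sanitized filename prefix
counter = 1                         # next free index in that folder
subfolder = ""                      # path relative to the output root
filename_prefix = "test"            # the prefix as originally requested

out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.png")
print(out_path)  # /tmp/output/test_00001_.png
```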
@@ -1,7 +1,10 @@
-import pytest
 import os
 import tempfile
-from folder_paths import filter_files_content_types
+import pytest

+from comfy.cmd.folder_paths import filter_files_content_types


 @pytest.fixture(scope="module")
 def file_extensions():
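A hedged usage sketch for `filter_files_content_types`, which takes file names and a list of coarse content types and keeps only the matching files. The sample names and the by-extension outcome are illustrative assumptions:

```python
# Hedged usage sketch; the expected output assumes content types are guessed
# from file extensions, which is an assumption about the implementation.
from comfy.cmd.folder_paths import filter_files_content_types

files = ["photo.png", "clip.mp4", "track.mp3", "notes.txt"]
images_and_audio = filter_files_content_types(files, ["image", "audio"])
# Expected under those assumptions: ["photo.png", "track.mp3"]
print(images_and_audio)
```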
 tests/unit/utils/__init__.py (new file, 0 lines)
@@ -1,10 +1,12 @@
-import pytest
-import yaml
 import os
 from unittest.mock import Mock, patch, mock_open

-from utils.extra_config import load_extra_path_config
-import folder_paths
+import pytest
+import yaml

+from comfy.cmd import folder_paths
+from comfy.extra_config import load_extra_path_config


 @pytest.fixture
 def mock_yaml_content():
@@ -15,10 +17,12 @@ def mock_yaml_content():
         }
     }


 @pytest.fixture
 def mock_expanded_home():
     return '/home/user'


 @pytest.fixture
 def yaml_config_with_appdata():
     return """
@@ -27,40 +31,47 @@ def yaml_config_with_appdata():
         checkpoints: 'models/checkpoints'
     """


 @pytest.fixture
 def mock_yaml_content_appdata(yaml_config_with_appdata):
     return yaml.safe_load(yaml_config_with_appdata)


 @pytest.fixture
 def mock_expandvars_appdata():
     mock = Mock()
     mock.side_effect = lambda path: path.replace('%APPDATA%', 'C:/Users/TestUser/AppData/Roaming')
     return mock


 @pytest.fixture
 def mock_add_model_folder_path():
     return Mock()


 @pytest.fixture
 def mock_expanduser(mock_expanded_home):
     def _expanduser(path):
         if path.startswith('~/'):
             return os.path.join(mock_expanded_home, path[2:])
         return path

     return _expanduser


 @pytest.fixture
 def mock_yaml_safe_load(mock_yaml_content):
     return Mock(return_value=mock_yaml_content)


 @patch('builtins.open', new_callable=mock_open, read_data="dummy file content")
 def test_load_extra_model_paths_expands_userpath(
     mock_file,
     monkeypatch,
     mock_add_model_folder_path,
     mock_expanduser,
     mock_yaml_safe_load,
     mock_expanded_home
 ):
     # Attach mocks used by load_extra_path_config
     monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path)
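The `mock_expandvars_appdata` fixture above substitutes `%APPDATA%` by hand because `%VAR%` syntax is only expanded by the Windows flavor of `expandvars`; on Linux and macOS the POSIX version leaves it untouched. A small sketch of the difference; the environment value is a hypothetical example:

```python
# Why the %APPDATA% fixture replaces the variable manually: %VAR% expansion
# is Windows-specific. ntpath.expandvars applies the Windows rules on any OS.
import ntpath
import os

os.environ["APPDATA"] = "C:/Users/TestUser/AppData/Roaming"
path = "%APPDATA%/comfyui/models/checkpoints"

print(ntpath.expandvars(path))
# C:/Users/TestUser/AppData/Roaming/comfyui/models/checkpoints

# The fixture's manual replacement achieves the same thing portably:
print(path.replace("%APPDATA%", os.environ["APPDATA"]))
```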
@@ -86,14 +97,15 @@ def test_load_extra_model_paths_expands_userpath(
     # Check if open was called with the correct file path
     mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')


 @patch('builtins.open', new_callable=mock_open)
 def test_load_extra_model_paths_expands_appdata(
     mock_file,
     monkeypatch,
     mock_add_model_folder_path,
     mock_expandvars_appdata,
     yaml_config_with_appdata,
     mock_yaml_content_appdata
 ):
     # Set the mock_file to return yaml with appdata as a variable
     mock_file.return_value.read.return_value = yaml_config_with_appdata