mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-14 07:22:36 +08:00
Merge branch 'comfyanonymous:master' into feat/is_change_object_storage
This commit is contained in:
commit
ce792da86a
31
.github/workflows/test-build.yml
vendored
Normal file
31
.github/workflows/test-build.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
name: Build package
|
||||||
|
|
||||||
|
#
|
||||||
|
# This workflow is a test of the python package build.
|
||||||
|
# Install Python dependencies across different Python versions.
|
||||||
|
#
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
paths:
|
||||||
|
- "requirements.txt"
|
||||||
|
- ".github/workflows/test-build.yml"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
@ -706,3 +706,34 @@ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disab
|
|||||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
||||||
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
||||||
|
|
||||||
|
|
||||||
|
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||||
|
alpha_cumprod = 1 / ((sigma * sigma) + 1)
|
||||||
|
alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
|
||||||
|
alpha = (alpha_cumprod / alpha_cumprod_prev)
|
||||||
|
|
||||||
|
mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
|
||||||
|
if sigma_prev > 0:
|
||||||
|
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
|
||||||
|
if sigmas[i + 1] != 0:
|
||||||
|
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||||
|
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
|
|
||||||
class LatentFormat:
|
class LatentFormat:
|
||||||
|
scale_factor = 1.0
|
||||||
|
latent_rgb_factors = None
|
||||||
|
taesd_decoder_name = None
|
||||||
|
|
||||||
def process_in(self, latent):
|
def process_in(self, latent):
|
||||||
return latent * self.scale_factor
|
return latent * self.scale_factor
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,6 @@ class DDIMSampler(object):
|
|||||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||||
|
|
||||||
self.register_buffer('betas', to_torch(self.model.betas))
|
|
||||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
@ -195,7 +194,7 @@ class DDIMSampler(object):
|
|||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
|
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
|
||||||
device = self.model.betas.device
|
device = self.model.alphas_cumprod.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
img = torch.randn(shape, device=device)
|
img = torch.randn(shape, device=device)
|
||||||
|
|||||||
@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel):
|
|||||||
out.append(self.embedder(torch.Tensor([crop_h])))
|
out.append(self.embedder(torch.Tensor([crop_h])))
|
||||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||||
out.append(self.embedder(torch.Tensor([aesthetic_score])))
|
out.append(self.embedder(torch.Tensor([aesthetic_score])))
|
||||||
flat = torch.flatten(torch.cat(out))[None, ]
|
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|
||||||
class SDXL(BaseModel):
|
class SDXL(BaseModel):
|
||||||
@ -206,5 +206,5 @@ class SDXL(BaseModel):
|
|||||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||||
out.append(self.embedder(torch.Tensor([target_height])))
|
out.append(self.embedder(torch.Tensor([target_height])))
|
||||||
out.append(self.embedder(torch.Tensor([target_width])))
|
out.append(self.embedder(torch.Tensor([target_width])))
|
||||||
flat = torch.flatten(torch.cat(out))[None, ]
|
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|||||||
@ -165,6 +165,9 @@ try:
|
|||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch.cuda.is_bf16_supported():
|
if torch.cuda.is_bf16_supported():
|
||||||
VAE_DTYPE = torch.bfloat16
|
VAE_DTYPE = torch.bfloat16
|
||||||
|
if is_intel_xpu():
|
||||||
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -478,6 +481,23 @@ def get_autocast_device(dev):
|
|||||||
return dev.type
|
return dev.type
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
|
device_supports_cast = False
|
||||||
|
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||||
|
device_supports_cast = True
|
||||||
|
elif tensor.dtype == torch.bfloat16:
|
||||||
|
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
||||||
|
device_supports_cast = True
|
||||||
|
|
||||||
|
if device_supports_cast:
|
||||||
|
if copy:
|
||||||
|
if tensor.device == device:
|
||||||
|
return tensor.to(dtype, copy=copy)
|
||||||
|
return tensor.to(device, copy=copy).to(dtype)
|
||||||
|
else:
|
||||||
|
return tensor.to(device).to(dtype)
|
||||||
|
else:
|
||||||
|
return tensor.to(dtype).to(device, copy=copy)
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||||
@ -154,7 +155,7 @@ class ModelPatcher:
|
|||||||
self.backup[key] = weight.to(self.offload_device)
|
self.backup[key] = weight.to(self.offload_device)
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = weight.float().to(device_to, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(torch.float32, copy=True)
|
||||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||||
@ -185,15 +186,15 @@ class ModelPatcher:
|
|||||||
if w1.shape != weight.shape:
|
if w1.shape != weight.shape:
|
||||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||||
else:
|
else:
|
||||||
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||||
elif len(v) == 4: #lora/locon
|
elif len(v) == 4: #lora/locon
|
||||||
mat1 = v[0].float().to(weight.device)
|
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||||
mat2 = v[1].float().to(weight.device)
|
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||||
if v[2] is not None:
|
if v[2] is not None:
|
||||||
alpha *= v[2] / mat2.shape[0]
|
alpha *= v[2] / mat2.shape[0]
|
||||||
if v[3] is not None:
|
if v[3] is not None:
|
||||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||||
mat3 = v[3].float().to(weight.device)
|
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||||
try:
|
try:
|
||||||
@ -212,18 +213,23 @@ class ModelPatcher:
|
|||||||
|
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
dim = w1_b.shape[0]
|
dim = w1_b.shape[0]
|
||||||
w1 = torch.mm(w1_a.float(), w1_b.float())
|
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w1 = w1.float().to(weight.device)
|
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
|
||||||
|
|
||||||
if w2 is None:
|
if w2 is None:
|
||||||
dim = w2_b.shape[0]
|
dim = w2_b.shape[0]
|
||||||
if t2 is None:
|
if t2 is None:
|
||||||
w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
|
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
w2 = w2.float().to(weight.device)
|
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
@ -244,11 +250,20 @@ class ModelPatcher:
|
|||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: #cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
|
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
|
||||||
|
|
||||||
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||||
|
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
|
||||||
else:
|
else:
|
||||||
m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
|
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
|
||||||
m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
|
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
|
||||||
|
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
|
||||||
|
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from .ldm.models.diffusion.ddim import DDIMSampler
|
|||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
import math
|
import math
|
||||||
from comfy import model_base
|
from comfy import model_base
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
||||||
return abs(a*b) // math.gcd(a, b)
|
return abs(a*b) // math.gcd(a, b)
|
||||||
@ -255,6 +256,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
else:
|
else:
|
||||||
transformer_options["patches"] = patches
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
c['transformer_options'] = transformer_options
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
if 'model_function_wrapper' in model_options:
|
||||||
@ -537,7 +539,7 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
|
|
||||||
if adm_out is not None:
|
if adm_out is not None:
|
||||||
x[1] = x[1].copy()
|
x[1] = x[1].copy()
|
||||||
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device)
|
x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
|
||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
@ -546,7 +548,7 @@ class KSampler:
|
|||||||
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
self.empty_tokens = [[49406] + [49407] * 76]
|
self.empty_tokens = [[49406] + [49407] * 76]
|
||||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||||
|
self.enable_attention_masks = False
|
||||||
|
|
||||||
self.layer_norm_hidden_state = True
|
self.layer_norm_hidden_state = True
|
||||||
if layer == "hidden":
|
if layer == "hidden":
|
||||||
@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(device), torch.float32):
|
with precision_scope(model_management.get_autocast_device(device), torch.float32):
|
||||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
attention_mask = None
|
||||||
|
if self.enable_attention_masks:
|
||||||
|
attention_mask = torch.zeros_like(tokens)
|
||||||
|
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
||||||
|
for x in range(attention_mask.shape[0]):
|
||||||
|
for y in range(attention_mask.shape[1]):
|
||||||
|
attention_mask[x, y] = 1
|
||||||
|
if tokens[x, y] == max_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
|
||||||
self.transformer.set_input_embeddings(backup_embeds)
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import math
|
|||||||
import struct
|
import struct
|
||||||
import comfy.checkpoint_pickle
|
import comfy.checkpoint_pickle
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -346,6 +348,13 @@ def bislerp(samples, width, height):
|
|||||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def lanczos(samples, width, height):
|
||||||
|
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||||
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||||
|
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||||
|
result = torch.stack(images)
|
||||||
|
return result
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
old_width = samples.shape[3]
|
old_width = samples.shape[3]
|
||||||
@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop):
|
|||||||
|
|
||||||
if upscale_method == "bislerp":
|
if upscale_method == "bislerp":
|
||||||
return bislerp(s, width, height)
|
return bislerp(s, width, height)
|
||||||
|
elif upscale_method == "lanczos":
|
||||||
|
return lanczos(s, width, height)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class ModelSubtract:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "merge"
|
FUNCTION = "merge"
|
||||||
|
|
||||||
CATEGORY = "_for_testing/model_merging"
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
def merge(self, model1, model2, multiplier):
|
def merge(self, model1, model2, multiplier):
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
@ -55,7 +55,7 @@ class ModelAdd:
|
|||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "merge"
|
FUNCTION = "merge"
|
||||||
|
|
||||||
CATEGORY = "_for_testing/model_merging"
|
CATEGORY = "advanced/model_merging"
|
||||||
|
|
||||||
def merge(self, model1, model2):
|
def merge(self, model1, model2):
|
||||||
m = model1.clone()
|
m = model1.clone()
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class Sharpen:
|
|||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
class ImageScaleToTotalPixels:
|
class ImageScaleToTotalPixels:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -54,7 +54,13 @@ class Example:
|
|||||||
"step": 64, #Slider's step
|
"step": 64, #Slider's step
|
||||||
"display": "number" # Cosmetic only: display as "number" or "slider"
|
"display": "number" # Cosmetic only: display as "number" or "slider"
|
||||||
}),
|
}),
|
||||||
"float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}),
|
"float_field": ("FLOAT", {
|
||||||
|
"default": 1.0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 10.0,
|
||||||
|
"step": 0.01,
|
||||||
|
"round": 0.001, #The value represeting the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
|
||||||
|
"display": "number"}),
|
||||||
"print_to_screen": (["enable", "disable"],),
|
"print_to_screen": (["enable", "disable"],),
|
||||||
"string_field": ("STRING", {
|
"string_field": ("STRING", {
|
||||||
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
|
||||||
|
|||||||
@ -53,7 +53,9 @@ def get_previewer(device, latent_format):
|
|||||||
method = args.preview_method
|
method = args.preview_method
|
||||||
if method != LatentPreviewMethod.NoPreviews:
|
if method != LatentPreviewMethod.NoPreviews:
|
||||||
# TODO previewer methods
|
# TODO previewer methods
|
||||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
|
taesd_decoder_path = None
|
||||||
|
if latent_format.taesd_decoder_name is not None:
|
||||||
|
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
|
||||||
|
|
||||||
if method == LatentPreviewMethod.Auto:
|
if method == LatentPreviewMethod.Auto:
|
||||||
method = LatentPreviewMethod.Latent2RGB
|
method = LatentPreviewMethod.Latent2RGB
|
||||||
@ -68,7 +70,8 @@ def get_previewer(device, latent_format):
|
|||||||
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
if latent_format.latent_rgb_factors is not None:
|
||||||
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
16
nodes.py
16
nodes.py
@ -543,8 +543,8 @@ class LoraLoader:
|
|||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
"clip": ("CLIP", ),
|
"clip": ("CLIP", ),
|
||||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP")
|
RETURN_TYPES = ("MODEL", "CLIP")
|
||||||
FUNCTION = "load_lora"
|
FUNCTION = "load_lora"
|
||||||
@ -889,8 +889,8 @@ class EmptyLatentImage:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
|
"height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
@ -1217,7 +1217,7 @@ class KSampler:
|
|||||||
{"model": ("MODEL",),
|
{"model": ("MODEL",),
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
||||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
||||||
"positive": ("CONDITIONING", ),
|
"positive": ("CONDITIONING", ),
|
||||||
@ -1243,7 +1243,7 @@ class KSamplerAdvanced:
|
|||||||
"add_noise": (["enable", "disable"], ),
|
"add_noise": (["enable", "disable"], ),
|
||||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||||
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
||||||
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
||||||
"positive": ("CONDITIONING", ),
|
"positive": ("CONDITIONING", ),
|
||||||
@ -1423,7 +1423,7 @@ class LoadImageMask:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
class ImageScale:
|
class ImageScale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
crop_methods = ["disabled", "center"]
|
crop_methods = ["disabled", "center"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1444,7 +1444,7 @@ class ImageScale:
|
|||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class ImageScaleBy:
|
class ImageScaleBy:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
5
pytest.ini
Normal file
5
pytest.ini
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[pytest]
|
||||||
|
markers =
|
||||||
|
inference: mark as inference test (deselect with '-m "not inference"')
|
||||||
|
testpaths = tests
|
||||||
|
addopts = -s
|
||||||
@ -132,12 +132,12 @@ class PromptServer():
|
|||||||
@routes.get("/extensions")
|
@routes.get("/extensions")
|
||||||
async def get_extensions(request):
|
async def get_extensions(request):
|
||||||
files = glob.glob(os.path.join(
|
files = glob.glob(os.path.join(
|
||||||
self.web_root, 'extensions/**/*.js'), recursive=True)
|
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||||
|
|
||||||
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
||||||
|
|
||||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||||
files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True)
|
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
|
||||||
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
|
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
|
||||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||||
|
|
||||||
|
|||||||
29
tests/README.md
Normal file
29
tests/README.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Automated Testing
|
||||||
|
|
||||||
|
## Running tests locally
|
||||||
|
|
||||||
|
Additional requirements for running tests:
|
||||||
|
```
|
||||||
|
pip install pytest
|
||||||
|
pip install websocket-client==1.6.1
|
||||||
|
opencv-python==4.6.0.66
|
||||||
|
scikit-image==0.21.0
|
||||||
|
```
|
||||||
|
Run inference tests:
|
||||||
|
```
|
||||||
|
pytest tests/inference
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quality regression test
|
||||||
|
Compares images in 2 directories to ensure they are the same
|
||||||
|
|
||||||
|
1) Run an inference test to save a directory of "ground truth" images
|
||||||
|
```
|
||||||
|
pytest tests/inference --output_dir tests/inference/baseline
|
||||||
|
```
|
||||||
|
2) Make code edits
|
||||||
|
|
||||||
|
3) Run inference and quality comparison tests
|
||||||
|
```
|
||||||
|
pytest
|
||||||
|
```
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
41
tests/compare/conftest.py
Normal file
41
tests/compare/conftest.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Command line arguments for pytest
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images')
|
||||||
|
parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test')
|
||||||
|
parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics')
|
||||||
|
parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images')
|
||||||
|
|
||||||
|
# This initializes args at the beginning of the test session
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def args_pytest(pytestconfig):
|
||||||
|
args = {}
|
||||||
|
args['baseline_dir'] = pytestconfig.getoption('baseline_dir')
|
||||||
|
args['test_dir'] = pytestconfig.getoption('test_dir')
|
||||||
|
args['metrics_file'] = pytestconfig.getoption('metrics_file')
|
||||||
|
args['img_output_dir'] = pytestconfig.getoption('img_output_dir')
|
||||||
|
|
||||||
|
# Initialize metrics file
|
||||||
|
with open(args['metrics_file'], 'a') as f:
|
||||||
|
# if file is empty, write header
|
||||||
|
if os.stat(args['metrics_file']).st_size == 0:
|
||||||
|
f.write("| date | run | file | status | value | \n")
|
||||||
|
f.write("| --- | --- | --- | --- | --- | \n")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def gather_file_basenames(directory: str):
|
||||||
|
files = []
|
||||||
|
for file in os.listdir(directory):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
files.append(file)
|
||||||
|
return files
|
||||||
|
|
||||||
|
# Creates the list of baseline file names to use as a fixture
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "baseline_fname" in metafunc.fixturenames:
|
||||||
|
baseline_fnames = gather_file_basenames(metafunc.config.getoption("baseline_dir"))
|
||||||
|
metafunc.parametrize("baseline_fname", baseline_fnames)
|
||||||
195
tests/compare/test_quality.py
Normal file
195
tests/compare/test_quality.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import pytest
|
||||||
|
from pytest import fixture
|
||||||
|
from typing import Tuple, List
|
||||||
|
|
||||||
|
from cv2 import imread, cvtColor, COLOR_BGR2RGB
|
||||||
|
from skimage.metrics import structural_similarity as ssim
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This test suite compares images in 2 directories by file name
|
||||||
|
The directories are specified by the command line arguments --baseline_dir and --test_dir
|
||||||
|
|
||||||
|
"""
|
||||||
|
# ssim: Structural Similarity Index
|
||||||
|
# Returns a tuple of (ssim, diff_image)
|
||||||
|
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
|
||||||
|
score, diff = ssim(img0, img1, channel_axis=-1, full=True)
|
||||||
|
# rescale the difference image to 0-255 range
|
||||||
|
diff = (diff * 255).astype("uint8")
|
||||||
|
return score, diff
|
||||||
|
|
||||||
|
# Metrics must return a tuple of (score, diff_image)
|
||||||
|
METRICS = {"ssim": ssim_score}
|
||||||
|
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompareImageMetrics:
|
||||||
|
@fixture(scope="class")
|
||||||
|
def test_file_names(self, args_pytest):
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
fnames = self.gather_file_basenames(test_dir)
|
||||||
|
yield fnames
|
||||||
|
del fnames
|
||||||
|
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def teardown(self, args_pytest):
|
||||||
|
yield
|
||||||
|
# Runs after all tests are complete
|
||||||
|
# Aggregate output files into a grid of images
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
metrics_file = args_pytest['metrics_file']
|
||||||
|
|
||||||
|
grid_dir = os.path.join(img_output_dir, "grid")
|
||||||
|
os.makedirs(grid_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for metric_dir in METRICS.keys():
|
||||||
|
metric_path = os.path.join(img_output_dir, metric_dir)
|
||||||
|
for file in os.listdir(metric_path):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
score = self.lookup_score_from_fname(file, metrics_file)
|
||||||
|
image_file_list = []
|
||||||
|
image_file_list.append([
|
||||||
|
os.path.join(baseline_dir, file),
|
||||||
|
os.path.join(test_dir, file),
|
||||||
|
os.path.join(metric_path, file)
|
||||||
|
])
|
||||||
|
# Create grid
|
||||||
|
image_list = [[Image.open(file) for file in files] for files in image_file_list]
|
||||||
|
grid = self.image_grid(image_list)
|
||||||
|
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
|
||||||
|
|
||||||
|
# Tests run for each baseline file name
|
||||||
|
@fixture()
|
||||||
|
def fname(self, baseline_fname):
|
||||||
|
yield baseline_fname
|
||||||
|
del baseline_fname
|
||||||
|
|
||||||
|
def test_directories_not_empty(self, args_pytest):
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
|
||||||
|
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
|
||||||
|
|
||||||
|
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
|
||||||
|
# Check that all files in baseline_dir have a file in test_dir with matching metadata
|
||||||
|
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
|
||||||
|
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
|
||||||
|
file_match = self.find_file_match(baseline_file_path, file_paths)
|
||||||
|
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
|
||||||
|
|
||||||
|
# For a baseline image file, finds the corresponding file name in test_dir and
|
||||||
|
# compares the images using the metrics in METRICS
|
||||||
|
@pytest.mark.parametrize("metric", METRICS.keys())
|
||||||
|
def test_pipeline_compare(
|
||||||
|
self,
|
||||||
|
args_pytest,
|
||||||
|
fname,
|
||||||
|
test_file_names,
|
||||||
|
metric,
|
||||||
|
):
|
||||||
|
baseline_dir = args_pytest['baseline_dir']
|
||||||
|
test_dir = args_pytest['test_dir']
|
||||||
|
metrics_output_file = args_pytest['metrics_file']
|
||||||
|
img_output_dir = args_pytest['img_output_dir']
|
||||||
|
|
||||||
|
baseline_file_path = os.path.join(baseline_dir, fname)
|
||||||
|
|
||||||
|
# Find file match
|
||||||
|
file_paths = [os.path.join(test_dir, f) for f in test_file_names]
|
||||||
|
test_file = self.find_file_match(baseline_file_path, file_paths)
|
||||||
|
|
||||||
|
# Run metrics
|
||||||
|
sample_baseline = self.read_img(baseline_file_path)
|
||||||
|
sample_secondary = self.read_img(test_file)
|
||||||
|
|
||||||
|
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
|
||||||
|
metric_status = score > METRICS_PASS_THRESHOLD[metric]
|
||||||
|
|
||||||
|
# Save metric values
|
||||||
|
with open(metrics_output_file, 'a') as f:
|
||||||
|
run_info = os.path.splitext(fname)[0]
|
||||||
|
metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
|
||||||
|
date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
|
||||||
|
|
||||||
|
# Save metric image
|
||||||
|
metric_img_dir = os.path.join(img_output_dir, metric)
|
||||||
|
os.makedirs(metric_img_dir, exist_ok=True)
|
||||||
|
output_filename = f'{fname}'
|
||||||
|
Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
|
||||||
|
|
||||||
|
assert score > METRICS_PASS_THRESHOLD[metric]
|
||||||
|
|
||||||
|
def read_img(self, filename: str) -> np.ndarray:
|
||||||
|
cvImg = imread(filename)
|
||||||
|
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
|
||||||
|
return cvImg
|
||||||
|
|
||||||
|
def image_grid(self, img_list: list[list[Image.Image]]):
|
||||||
|
# imgs is a 2D list of images
|
||||||
|
# Assumes the input images are a rectangular grid of equal sized images
|
||||||
|
rows = len(img_list)
|
||||||
|
cols = len(img_list[0])
|
||||||
|
|
||||||
|
w, h = img_list[0][0].size
|
||||||
|
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||||
|
|
||||||
|
for i, row in enumerate(img_list):
|
||||||
|
for j, img in enumerate(row):
|
||||||
|
grid.paste(img, box=(j*w, i*h))
|
||||||
|
return grid
|
||||||
|
|
||||||
|
def lookup_score_from_fname(self,
|
||||||
|
fname: str,
|
||||||
|
metrics_output_file: str
|
||||||
|
) -> float:
|
||||||
|
fname_basestr = os.path.splitext(fname)[0]
|
||||||
|
with open(metrics_output_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
if fname_basestr in line:
|
||||||
|
score = float(line.split('|')[5])
|
||||||
|
return score
|
||||||
|
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
|
||||||
|
|
||||||
|
def gather_file_basenames(self, directory: str):
|
||||||
|
files = []
|
||||||
|
for file in os.listdir(directory):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
files.append(file)
|
||||||
|
return files
|
||||||
|
|
||||||
|
def read_file_prompt(self, fname:str) -> str:
|
||||||
|
# Read prompt from image file metadata
|
||||||
|
img = Image.open(fname)
|
||||||
|
img.load()
|
||||||
|
return img.info['prompt']
|
||||||
|
|
||||||
|
def find_file_match(self, baseline_file: str, file_paths: List[str]):
|
||||||
|
# Find a file in file_paths with matching metadata to baseline_file
|
||||||
|
baseline_prompt = self.read_file_prompt(baseline_file)
|
||||||
|
|
||||||
|
# Do not match empty prompts
|
||||||
|
if baseline_prompt is None or baseline_prompt == "":
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find file match
|
||||||
|
# Reorder test_file_names so that the file with matching name is first
|
||||||
|
# This is an optimization because matching file names are more likely
|
||||||
|
# to have matching metadata if they were generated with the same script
|
||||||
|
basename = os.path.basename(baseline_file)
|
||||||
|
file_path_basenames = [os.path.basename(f) for f in file_paths]
|
||||||
|
if basename in file_path_basenames:
|
||||||
|
match_index = file_path_basenames.index(basename)
|
||||||
|
file_paths.insert(0, file_paths.pop(match_index))
|
||||||
|
|
||||||
|
for f in file_paths:
|
||||||
|
test_file_prompt = self.read_file_prompt(f)
|
||||||
|
if baseline_prompt == test_file_prompt:
|
||||||
|
return f
|
||||||
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Command line arguments for pytest
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
|
||||||
|
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||||
|
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
|
||||||
|
|
||||||
|
# This initializes args at the beginning of the test session
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def args_pytest(pytestconfig):
|
||||||
|
args = {}
|
||||||
|
args['output_dir'] = pytestconfig.getoption('output_dir')
|
||||||
|
args['listen'] = pytestconfig.getoption('listen')
|
||||||
|
args['port'] = pytestconfig.getoption('port')
|
||||||
|
|
||||||
|
os.makedirs(args['output_dir'], exist_ok=True)
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(items):
|
||||||
|
# Modifies items so tests run in the correct order
|
||||||
|
|
||||||
|
LAST_TESTS = ['test_quality']
|
||||||
|
|
||||||
|
# Move the last items to the end
|
||||||
|
last_items = []
|
||||||
|
for test_name in LAST_TESTS:
|
||||||
|
for item in items.copy():
|
||||||
|
print(item.module.__name__, item)
|
||||||
|
if item.module.__name__ == test_name:
|
||||||
|
last_items.append(item)
|
||||||
|
items.remove(item)
|
||||||
|
|
||||||
|
items.extend(last_items)
|
||||||
0
tests/inference/__init__.py
Normal file
0
tests/inference/__init__.py
Normal file
144
tests/inference/graphs/default_graph_sdxl1_0.json
Normal file
144
tests/inference/graphs/default_graph_sdxl1_0.json
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
{
|
||||||
|
"4": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "sd_xl_base_1.0.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple"
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 1024,
|
||||||
|
"height": 1024,
|
||||||
|
"batch_size": 1
|
||||||
|
},
|
||||||
|
"class_type": "EmptyLatentImage"
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "a photo of a cat",
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode"
|
||||||
|
},
|
||||||
|
"10": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "enable",
|
||||||
|
"noise_seed": 42,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 7.5,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"start_at_step": 0,
|
||||||
|
"end_at_step": 32,
|
||||||
|
"return_with_leftover_noise": "enable",
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"15",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced"
|
||||||
|
},
|
||||||
|
"12": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"14",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode"
|
||||||
|
},
|
||||||
|
"13": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "test_inference",
|
||||||
|
"images": [
|
||||||
|
"12",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImage"
|
||||||
|
},
|
||||||
|
"14": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "disable",
|
||||||
|
"noise_seed": 42,
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 7.5,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"start_at_step": 32,
|
||||||
|
"end_at_step": 10000,
|
||||||
|
"return_with_leftover_noise": "disable",
|
||||||
|
"model": [
|
||||||
|
"16",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"17",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"20",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"10",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced"
|
||||||
|
},
|
||||||
|
"15": {
|
||||||
|
"inputs": {
|
||||||
|
"conditioning": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ConditioningZeroOut"
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "sd_xl_refiner_1.0.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "CheckpointLoaderSimple"
|
||||||
|
},
|
||||||
|
"17": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "a photo of a cat",
|
||||||
|
"clip": [
|
||||||
|
"16",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode"
|
||||||
|
},
|
||||||
|
"20": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "",
|
||||||
|
"clip": [
|
||||||
|
"16",
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode"
|
||||||
|
}
|
||||||
|
}
|
||||||
239
tests/inference/test_inference.py
Normal file
239
tests/inference/test_inference.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from io import BytesIO
|
||||||
|
from urllib import request
|
||||||
|
import numpy
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import pytest
|
||||||
|
from pytest import fixture
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
|
||||||
|
from comfy.samplers import KSampler
|
||||||
|
|
||||||
|
"""
|
||||||
|
These tests generate and save images through a range of parameters
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ComfyGraph:
|
||||||
|
def __init__(self,
|
||||||
|
graph: dict,
|
||||||
|
sampler_nodes: list[str],
|
||||||
|
):
|
||||||
|
self.graph = graph
|
||||||
|
self.sampler_nodes = sampler_nodes
|
||||||
|
|
||||||
|
def set_prompt(self, prompt, negative_prompt=None):
|
||||||
|
# Sets the prompt for the sampler nodes (eg. base and refiner)
|
||||||
|
for node in self.sampler_nodes:
|
||||||
|
prompt_node = self.graph[node]['inputs']['positive'][0]
|
||||||
|
self.graph[prompt_node]['inputs']['text'] = prompt
|
||||||
|
if negative_prompt:
|
||||||
|
negative_prompt_node = self.graph[node]['inputs']['negative'][0]
|
||||||
|
self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt
|
||||||
|
|
||||||
|
def set_sampler_name(self, sampler_name:str, ):
|
||||||
|
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||||
|
for node in self.sampler_nodes:
|
||||||
|
self.graph[node]['inputs']['sampler_name'] = sampler_name
|
||||||
|
|
||||||
|
def set_scheduler(self, scheduler:str):
|
||||||
|
# sets the sampler name for the sampler nodes (eg. base and refiner)
|
||||||
|
for node in self.sampler_nodes:
|
||||||
|
self.graph[node]['inputs']['scheduler'] = scheduler
|
||||||
|
|
||||||
|
def set_filename_prefix(self, prefix:str):
|
||||||
|
# sets the filename prefix for the save nodes
|
||||||
|
for node in self.graph:
|
||||||
|
if self.graph[node]['class_type'] == 'SaveImage':
|
||||||
|
self.graph[node]['inputs']['filename_prefix'] = prefix
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyClient:
|
||||||
|
# From examples/websockets_api_example.py
|
||||||
|
|
||||||
|
def connect(self,
|
||||||
|
listen:str = '127.0.0.1',
|
||||||
|
port:Union[str,int] = 8188,
|
||||||
|
client_id: str = str(uuid.uuid4())
|
||||||
|
):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.server_address = f"{listen}:{port}"
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id))
|
||||||
|
self.ws = ws
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt):
|
||||||
|
p = {"prompt": prompt, "client_id": self.client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data)
|
||||||
|
return json.loads(urllib.request.urlopen(req).read())
|
||||||
|
|
||||||
|
def get_image(self, filename, subfolder, folder_type):
|
||||||
|
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||||
|
url_values = urllib.parse.urlencode(data)
|
||||||
|
with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response:
|
||||||
|
return response.read()
|
||||||
|
|
||||||
|
def get_history(self, prompt_id):
|
||||||
|
with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response:
|
||||||
|
return json.loads(response.read())
|
||||||
|
|
||||||
|
def get_images(self, graph, save=True):
|
||||||
|
prompt = graph
|
||||||
|
if not save:
|
||||||
|
# Replace save nodes with preview nodes
|
||||||
|
prompt_str = json.dumps(prompt)
|
||||||
|
prompt_str = prompt_str.replace('SaveImage', 'PreviewImage')
|
||||||
|
prompt = json.loads(prompt_str)
|
||||||
|
|
||||||
|
prompt_id = self.queue_prompt(prompt)['prompt_id']
|
||||||
|
output_images = {}
|
||||||
|
while True:
|
||||||
|
out = self.ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||||
|
break #Execution is done
|
||||||
|
else:
|
||||||
|
continue #previews are binary data
|
||||||
|
|
||||||
|
history = self.get_history(prompt_id)[prompt_id]
|
||||||
|
for o in history['outputs']:
|
||||||
|
for node_id in history['outputs']:
|
||||||
|
node_output = history['outputs'][node_id]
|
||||||
|
if 'images' in node_output:
|
||||||
|
images_output = []
|
||||||
|
for image in node_output['images']:
|
||||||
|
image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
|
||||||
|
images_output.append(image_data)
|
||||||
|
output_images[node_id] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
#
|
||||||
|
# Initialize graphs
|
||||||
|
#
|
||||||
|
default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json'
|
||||||
|
with open(default_graph_file, 'r') as file:
|
||||||
|
default_graph = json.loads(file.read())
|
||||||
|
DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14'])
|
||||||
|
DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0]
|
||||||
|
|
||||||
|
#
|
||||||
|
# Loop through these variables
|
||||||
|
#
|
||||||
|
comfy_graph_list = [DEFAULT_COMFY_GRAPH]
|
||||||
|
comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID]
|
||||||
|
prompt_list = [
|
||||||
|
'a painting of a cat',
|
||||||
|
]
|
||||||
|
|
||||||
|
sampler_list = KSampler.SAMPLERS
|
||||||
|
scheduler_list = KSampler.SCHEDULERS
|
||||||
|
|
||||||
|
@pytest.mark.inference
|
||||||
|
@pytest.mark.parametrize("sampler", sampler_list)
|
||||||
|
@pytest.mark.parametrize("scheduler", scheduler_list)
|
||||||
|
@pytest.mark.parametrize("prompt", prompt_list)
|
||||||
|
class TestInference:
|
||||||
|
#
|
||||||
|
# Initialize server and client
|
||||||
|
#
|
||||||
|
@fixture(scope="class", autouse=True)
|
||||||
|
def _server(self, args_pytest):
|
||||||
|
# Start server
|
||||||
|
p = subprocess.Popen([
|
||||||
|
'python','main.py',
|
||||||
|
'--output-directory', args_pytest["output_dir"],
|
||||||
|
'--listen', args_pytest["listen"],
|
||||||
|
'--port', str(args_pytest["port"]),
|
||||||
|
])
|
||||||
|
yield
|
||||||
|
p.kill()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def start_client(self, listen:str, port:int):
|
||||||
|
# Start client
|
||||||
|
comfy_client = ComfyClient()
|
||||||
|
# Connect to server (with retries)
|
||||||
|
n_tries = 5
|
||||||
|
for i in range(n_tries):
|
||||||
|
time.sleep(4)
|
||||||
|
try:
|
||||||
|
comfy_client.connect(listen=listen, port=port)
|
||||||
|
except ConnectionRefusedError as e:
|
||||||
|
print(e)
|
||||||
|
print(f"({i+1}/{n_tries}) Retrying...")
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return comfy_client
|
||||||
|
|
||||||
|
#
|
||||||
|
# Client and graph fixtures with server warmup
|
||||||
|
#
|
||||||
|
# Returns a "_client_graph", which is client-graph pair corresponding to an initialized server
|
||||||
|
# The "graph" is the default graph
|
||||||
|
@fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True)
|
||||||
|
def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph):
|
||||||
|
comfy_graph = request.param
|
||||||
|
|
||||||
|
# Start client
|
||||||
|
comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||||
|
|
||||||
|
# Warm up pipeline
|
||||||
|
comfy_client.get_images(graph=comfy_graph.graph, save=False)
|
||||||
|
|
||||||
|
yield comfy_client, comfy_graph
|
||||||
|
del comfy_client
|
||||||
|
del comfy_graph
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def client(self, _client_graph):
|
||||||
|
client = _client_graph[0]
|
||||||
|
yield client
|
||||||
|
|
||||||
|
@fixture
|
||||||
|
def comfy_graph(self, _client_graph):
|
||||||
|
# avoid mutating the graph
|
||||||
|
graph = deepcopy(_client_graph[1])
|
||||||
|
yield graph
|
||||||
|
|
||||||
|
def test_comfy(
|
||||||
|
self,
|
||||||
|
client,
|
||||||
|
comfy_graph,
|
||||||
|
sampler,
|
||||||
|
scheduler,
|
||||||
|
prompt,
|
||||||
|
request
|
||||||
|
):
|
||||||
|
test_info = request.node.name
|
||||||
|
comfy_graph.set_filename_prefix(test_info)
|
||||||
|
# Settings for comfy graph
|
||||||
|
comfy_graph.set_sampler_name(sampler)
|
||||||
|
comfy_graph.set_scheduler(scheduler)
|
||||||
|
comfy_graph.set_prompt(prompt)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
images = client.get_images(comfy_graph.graph)
|
||||||
|
|
||||||
|
assert len(images) != 0, "No images generated"
|
||||||
|
# assert all images are not blank
|
||||||
|
for images_output in images.values():
|
||||||
|
for image_data in images_output:
|
||||||
|
pil_image = Image.open(BytesIO(image_data))
|
||||||
|
assert numpy.array(pil_image).any() != 0, "Image is blank"
|
||||||
|
|
||||||
|
|
||||||
@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
this.title = o.title;
|
this.title = o.title;
|
||||||
this._bounding.set(o.bounding);
|
this._bounding.set(o.bounding);
|
||||||
this.color = o.color;
|
this.color = o.color;
|
||||||
this.font = o.font;
|
if (o.font_size) {
|
||||||
|
this.font_size = o.font_size;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
LGraphGroup.prototype.serialize = function() {
|
LGraphGroup.prototype.serialize = function() {
|
||||||
@ -4942,7 +4944,7 @@ LGraphNode.prototype.executeAction = function(action)
|
|||||||
Math.round(b[3])
|
Math.round(b[3])
|
||||||
],
|
],
|
||||||
color: this.color,
|
color: this.color,
|
||||||
font: this.font
|
font_size: this.font_size
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -532,7 +532,17 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.imageRects.push([x, y, cellWidth, cellHeight]);
|
this.imageRects.push([x, y, cellWidth, cellHeight]);
|
||||||
ctx.drawImage(img, x, y, cellWidth, cellHeight);
|
|
||||||
|
let wratio = cellWidth/img.width;
|
||||||
|
let hratio = cellHeight/img.height;
|
||||||
|
var ratio = Math.min(wratio, hratio);
|
||||||
|
|
||||||
|
let imgHeight = ratio * img.height;
|
||||||
|
let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2;
|
||||||
|
let imgWidth = ratio * img.width;
|
||||||
|
let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2;
|
||||||
|
|
||||||
|
ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight);
|
||||||
ctx.filter = "none";
|
ctx.filter = "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -743,8 +753,9 @@ export class ComfyApp {
|
|||||||
// Default system copy
|
// Default system copy
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy nodes and clear clipboard
|
// copy nodes and clear clipboard
|
||||||
if (this.canvas.selected_nodes) {
|
if (e.target.className === "litegraph" && this.canvas.selected_nodes) {
|
||||||
this.canvas.copyToClipboard();
|
this.canvas.copyToClipboard();
|
||||||
e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard
|
e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
@ -1297,7 +1308,13 @@ export class ComfyApp {
|
|||||||
|
|
||||||
let reset_invalid_values = false;
|
let reset_invalid_values = false;
|
||||||
if (!graphData) {
|
if (!graphData) {
|
||||||
graphData = structuredClone(defaultGraph);
|
if (typeof structuredClone === "undefined")
|
||||||
|
{
|
||||||
|
graphData = JSON.parse(JSON.stringify(defaultGraph));
|
||||||
|
}else
|
||||||
|
{
|
||||||
|
graphData = structuredClone(defaultGraph);
|
||||||
|
}
|
||||||
reset_invalid_values = true;
|
reset_invalid_values = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -577,6 +577,25 @@ export class ComfyUI {
|
|||||||
defaultValue: false,
|
defaultValue: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this.settings.addSetting({
|
||||||
|
id: "Comfy.DisableFloatRounding",
|
||||||
|
name: "Disable rounding floats (requires page reload).",
|
||||||
|
type: "boolean",
|
||||||
|
defaultValue: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
this.settings.addSetting({
|
||||||
|
id: "Comfy.FloatRoundingPrecision",
|
||||||
|
name: "Decimal places [0 = auto] (requires page reload).",
|
||||||
|
type: "slider",
|
||||||
|
attrs: {
|
||||||
|
min: 0,
|
||||||
|
max: 6,
|
||||||
|
step: 1,
|
||||||
|
},
|
||||||
|
defaultValue: 0,
|
||||||
|
});
|
||||||
|
|
||||||
const fileInput = $el("input", {
|
const fileInput = $el("input", {
|
||||||
id: "comfy-file-input",
|
id: "comfy-file-input",
|
||||||
type: "file",
|
type: "file",
|
||||||
|
|||||||
@ -1,18 +1,23 @@
|
|||||||
import { api } from "./api.js"
|
import { api } from "./api.js"
|
||||||
|
|
||||||
function getNumberDefaults(inputData, defaultStep) {
|
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||||
let defaultVal = inputData[1]["default"];
|
let defaultVal = inputData[1]["default"];
|
||||||
let { min, max, step } = inputData[1];
|
let { min, max, step, round} = inputData[1];
|
||||||
|
|
||||||
if (defaultVal == undefined) defaultVal = 0;
|
if (defaultVal == undefined) defaultVal = 0;
|
||||||
if (min == undefined) min = 0;
|
if (min == undefined) min = 0;
|
||||||
if (max == undefined) max = 2048;
|
if (max == undefined) max = 2048;
|
||||||
if (step == undefined) step = defaultStep;
|
if (step == undefined) step = defaultStep;
|
||||||
// precision is the number of decimal places to show.
|
// precision is the number of decimal places to show.
|
||||||
// by default, display the the smallest number of decimal places such that changes of size step are visible.
|
// by default, display the the smallest number of decimal places such that changes of size step are visible.
|
||||||
let precision = Math.max(-Math.floor(Math.log10(step)),0)
|
if (precision == undefined) {
|
||||||
// by default, round the value to those decimal places shown.
|
precision = Math.max(-Math.floor(Math.log10(step)),0);
|
||||||
let round = Math.round(1000000*Math.pow(0.1,precision))/1000000;
|
}
|
||||||
|
|
||||||
|
if (enable_rounding && (round == undefined || round === true)) {
|
||||||
|
// by default, round the value to those decimal places shown.
|
||||||
|
round = Math.round(1000000*Math.pow(0.1,precision))/1000000;
|
||||||
|
}
|
||||||
|
|
||||||
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
|
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
|
||||||
}
|
}
|
||||||
@ -268,15 +273,22 @@ export const ComfyWidgets = {
|
|||||||
"INT:noise_seed": seedWidget,
|
"INT:noise_seed": seedWidget,
|
||||||
FLOAT(node, inputName, inputData, app) {
|
FLOAT(node, inputName, inputData, app) {
|
||||||
let widgetType = isSlider(inputData[1]["display"], app);
|
let widgetType = isSlider(inputData[1]["display"], app);
|
||||||
const { val, config } = getNumberDefaults(inputData, 0.5);
|
let precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision");
|
||||||
|
let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding")
|
||||||
|
if (precision == 0) precision = undefined;
|
||||||
|
const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding);
|
||||||
return { widget: node.addWidget(widgetType, inputName, val,
|
return { widget: node.addWidget(widgetType, inputName, val,
|
||||||
function (v) {
|
function (v) {
|
||||||
this.value = Math.round(v/config.round)*config.round;
|
if (config.round) {
|
||||||
|
this.value = Math.round(v/config.round)*config.round;
|
||||||
|
} else {
|
||||||
|
this.value = v;
|
||||||
|
}
|
||||||
}, config) };
|
}, config) };
|
||||||
},
|
},
|
||||||
INT(node, inputName, inputData, app) {
|
INT(node, inputName, inputData, app) {
|
||||||
let widgetType = isSlider(inputData[1]["display"], app);
|
let widgetType = isSlider(inputData[1]["display"], app);
|
||||||
const { val, config } = getNumberDefaults(inputData, 1);
|
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
|
||||||
Object.assign(config, { precision: 0 });
|
Object.assign(config, { precision: 0 });
|
||||||
return {
|
return {
|
||||||
widget: node.addWidget(
|
widget: node.addWidget(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user