mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 18:43:05 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
d2ed1dcb9a
@ -141,8 +141,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u
|
|||||||
class PerformanceFeature(enum.Enum):
|
class PerformanceFeature(enum.Enum):
|
||||||
Fp16Accumulation = "fp16_accumulation"
|
Fp16Accumulation = "fp16_accumulation"
|
||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
|
CublasOps = "cublas_ops"
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
|
|||||||
@ -1422,3 +1422,101 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
|||||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||||
old_denoised = denoised
|
old_denoised = denoised
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||||
|
'''
|
||||||
|
SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
|
||||||
|
Arxiv: https://arxiv.org/abs/2305.14267
|
||||||
|
'''
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
inject_noise = eta > 0 and s_noise > 0
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = t_next - t
|
||||||
|
h_eta = h * (eta + 1)
|
||||||
|
s = t + r * h
|
||||||
|
fac = 1 / (2 * r)
|
||||||
|
sigma_s = s.neg().exp()
|
||||||
|
|
||||||
|
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
|
||||||
|
if inject_noise:
|
||||||
|
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
|
||||||
|
noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||||
|
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
|
||||||
|
|
||||||
|
# Step 1
|
||||||
|
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||||
|
if inject_noise:
|
||||||
|
x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
|
||||||
|
denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
|
||||||
|
|
||||||
|
# Step 2
|
||||||
|
denoised_d = (1 - fac) * denoised + fac * denoised_2
|
||||||
|
x = (coeff_2 + 1) * x - coeff_2 * denoised_d
|
||||||
|
if inject_noise:
|
||||||
|
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
|
||||||
|
'''
|
||||||
|
SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
|
||||||
|
Arxiv: https://arxiv.org/abs/2305.14267
|
||||||
|
'''
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
seed = extra_args.get("seed", None)
|
||||||
|
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
|
inject_noise = eta > 0 and s_noise > 0
|
||||||
|
|
||||||
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
x = denoised
|
||||||
|
else:
|
||||||
|
t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||||
|
h = t_next - t
|
||||||
|
h_eta = h * (eta + 1)
|
||||||
|
s_1 = t + r_1 * h
|
||||||
|
s_2 = t + r_2 * h
|
||||||
|
sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
|
||||||
|
|
||||||
|
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
|
||||||
|
if inject_noise:
|
||||||
|
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
|
||||||
|
noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
|
||||||
|
noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
|
||||||
|
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
|
||||||
|
|
||||||
|
# Step 1
|
||||||
|
x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
|
||||||
|
if inject_noise:
|
||||||
|
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
|
||||||
|
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||||
|
|
||||||
|
# Step 2
|
||||||
|
x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
|
||||||
|
if inject_noise:
|
||||||
|
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
|
||||||
|
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
|
||||||
|
|
||||||
|
# Step 3
|
||||||
|
x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
|
||||||
|
if inject_noise:
|
||||||
|
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
|
||||||
|
return x
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.ops
|
import comfy.rmsnorm
|
||||||
|
|
||||||
|
|
||||||
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
||||||
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
|
|||||||
|
|
||||||
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
return torch.nn.functional.pad(img, pad, mode=padding_mode)
|
||||||
|
|
||||||
try:
|
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
|
||||||
except:
|
|
||||||
rms_norm_torch = None
|
|
||||||
|
|
||||||
def rms_norm(x, weight=None, eps=1e-6):
|
rms_norm = comfy.rmsnorm.rms_norm
|
||||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
|
||||||
if weight is None:
|
|
||||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
|
||||||
else:
|
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
|
||||||
else:
|
|
||||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
|
||||||
if weight is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
|
|
||||||
|
|||||||
48
comfy/ops.py
48
comfy/ops.py
@ -21,6 +21,7 @@ import logging
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
import comfy.rmsnorm
|
||||||
|
|
||||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
|
|
||||||
@ -146,6 +147,25 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.bias = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
if self.weight is not None:
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
else:
|
||||||
|
weight = None
|
||||||
|
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||||
|
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
@ -357,6 +377,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
|
|
||||||
return scaled_fp8_op
|
return scaled_fp8_op
|
||||||
|
|
||||||
|
CUBLAS_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from cublas_ops import CublasLinear
|
||||||
|
CUBLAS_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if CUBLAS_IS_AVAILABLE:
|
||||||
|
class cublas_ops(disable_weight_init):
|
||||||
|
class Linear(CublasLinear, disable_weight_init.Linear):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
return super().forward(input)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
if scaled_fp8 is not None:
|
||||||
@ -369,6 +408,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
|||||||
):
|
):
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
|
|
||||||
|
if (
|
||||||
|
PerformanceFeature.CublasOps in args.fast and
|
||||||
|
CUBLAS_IS_AVAILABLE and
|
||||||
|
weight_dtype == torch.float16 and
|
||||||
|
(compute_dtype == torch.float16 or compute_dtype is None)
|
||||||
|
):
|
||||||
|
logging.info("Using cublas ops")
|
||||||
|
return cublas_ops
|
||||||
|
|
||||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||||
return disable_weight_init
|
return disable_weight_init
|
||||||
|
|
||||||
|
|||||||
65
comfy/rmsnorm.py
Normal file
65
comfy/rmsnorm.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import numbers
|
||||||
|
|
||||||
|
RMSNorm = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
rms_norm_torch = torch.nn.functional.rms_norm
|
||||||
|
RMSNorm = torch.nn.RMSNorm
|
||||||
|
except:
|
||||||
|
rms_norm_torch = None
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
|
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||||
|
if weight is None:
|
||||||
|
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||||
|
else:
|
||||||
|
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
|
else:
|
||||||
|
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||||
|
if weight is None:
|
||||||
|
return r
|
||||||
|
else:
|
||||||
|
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
|
if RMSNorm is None:
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.learnable_scale = elementwise_affine
|
||||||
|
if self.learnable_scale:
|
||||||
|
self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
normalized_shape,
|
||||||
|
eps=None,
|
||||||
|
elementwise_affine=True,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(normalized_shape, numbers.Integral):
|
||||||
|
# mypy error: incompatible types in assignment
|
||||||
|
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||||
|
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||||
|
self.eps = eps
|
||||||
|
self.elementwise_affine = elementwise_affine
|
||||||
|
if self.elementwise_affine:
|
||||||
|
self.weight = torch.nn.Parameter(
|
||||||
|
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return rms_norm(x, self.weight, self.eps)
|
||||||
@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "er_sde"]
|
"gradient_estimation", "er_sde", "seeds_2", "seeds_3"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
|||||||
56
comfy_extras/nodes_optimalsteps.py
Normal file
56
comfy_extras/nodes_optimalsteps.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# from https://github.com/bebebe666/OptimalSteps
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def loglinear_interp(t_steps, num_steps):
|
||||||
|
"""
|
||||||
|
Performs log-linear interpolation of a given array of decreasing numbers.
|
||||||
|
"""
|
||||||
|
xs = np.linspace(0, 1, len(t_steps))
|
||||||
|
ys = np.log(t_steps[::-1])
|
||||||
|
|
||||||
|
new_xs = np.linspace(0, 1, num_steps)
|
||||||
|
new_ys = np.interp(new_xs, xs, ys)
|
||||||
|
|
||||||
|
interped_ys = np.exp(new_ys)[::-1].copy()
|
||||||
|
return interped_ys
|
||||||
|
|
||||||
|
|
||||||
|
NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
|
||||||
|
"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
|
||||||
|
}
|
||||||
|
|
||||||
|
class OptimalStepsScheduler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required":
|
||||||
|
{"model_type": (["FLUX", "Wan"], ),
|
||||||
|
"steps": ("INT", {"default": 20, "min": 3, "max": 1000}),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
|
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||||
|
|
||||||
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
|
def get_sigmas(self, model_type, steps, denoise):
|
||||||
|
total_steps = steps
|
||||||
|
if denoise < 1.0:
|
||||||
|
if denoise <= 0.0:
|
||||||
|
return (torch.FloatTensor([]),)
|
||||||
|
total_steps = round(steps * denoise)
|
||||||
|
|
||||||
|
sigmas = NOISE_LEVELS[model_type][:]
|
||||||
|
if (steps + 1) != len(sigmas):
|
||||||
|
sigmas = loglinear_interp(sigmas, steps + 1)
|
||||||
|
|
||||||
|
sigmas = sigmas[-(total_steps + 1):]
|
||||||
|
sigmas[-1] = 0
|
||||||
|
return (torch.FloatTensor(sigmas), )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"OptimalStepsScheduler": OptimalStepsScheduler,
|
||||||
|
}
|
||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.27"
|
__version__ = "0.3.28"
|
||||||
|
|||||||
@ -85,6 +85,7 @@ cache_helper = CacheHelper()
|
|||||||
|
|
||||||
extension_mimetypes_cache = {
|
extension_mimetypes_cache = {
|
||||||
"webp" : "image",
|
"webp" : "image",
|
||||||
|
"fbx" : "model",
|
||||||
}
|
}
|
||||||
|
|
||||||
def map_legacy(folder_name: str) -> str:
|
def map_legacy(folder_name: str) -> str:
|
||||||
@ -140,11 +141,14 @@ def get_directory_by_type(type_name: str) -> str | None:
|
|||||||
return get_input_directory()
|
return get_input_directory()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]:
|
def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Example:
|
Example:
|
||||||
files = os.listdir(folder_paths.get_input_directory())
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
filter_files_content_types(files, ["image", "audio", "video"])
|
videos = filter_files_content_types(files, ["video"])
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- 'model' in MIME context refers to 3D models, not files containing trained weights and parameters
|
||||||
"""
|
"""
|
||||||
global extension_mimetypes_cache
|
global extension_mimetypes_cache
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
2
main.py
2
main.py
@ -10,6 +10,7 @@ from app.logger import setup_logger
|
|||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import comfyui_manager
|
import comfyui_manager
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -310,6 +311,7 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Running directly, just start ComfyUI.
|
# Running directly, just start ComfyUI.
|
||||||
|
logging.info("Python version: {}".format(sys.version))
|
||||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||||
|
|
||||||
event_loop, _, start_all_func = start_comfyui()
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -1655,6 +1655,7 @@ class LoadImage:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
|
files = folder_paths.filter_files_content_types(files, ["image"])
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(files), {"image_upload": True})},
|
{"image": (sorted(files), {"image_upload": True})},
|
||||||
}
|
}
|
||||||
@ -2284,6 +2285,7 @@ def init_builtin_extra_nodes():
|
|||||||
"nodes_hunyuan3d.py",
|
"nodes_hunyuan3d.py",
|
||||||
"nodes_primitive.py",
|
"nodes_primitive.py",
|
||||||
"nodes_cfg.py",
|
"nodes_cfg.py",
|
||||||
|
"nodes_optimalsteps.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.27"
|
version = "0.3.28"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@ -1,14 +1,17 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from folder_paths import filter_files_content_types
|
from folder_paths import filter_files_content_types, extension_mimetypes_cache
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def file_extensions():
|
def file_extensions():
|
||||||
return {
|
return {
|
||||||
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
|
||||||
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
|
||||||
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
|
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'],
|
||||||
|
'model': ['gltf', 'glb', 'obj', 'fbx', 'stl']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +25,18 @@ def mock_dir(file_extensions):
|
|||||||
yield directory
|
yield directory
|
||||||
|
|
||||||
|
|
||||||
def test_categorizes_all_correctly(mock_dir, file_extensions):
|
@pytest.fixture
|
||||||
|
def patched_mimetype_cache(file_extensions):
|
||||||
|
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
|
||||||
|
new_cache = extension_mimetypes_cache.copy()
|
||||||
|
for extension in file_extensions["model"]:
|
||||||
|
new_cache[extension] = "model"
|
||||||
|
|
||||||
|
with patch("folder_paths.extension_mimetypes_cache", new_cache):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
|
||||||
files = os.listdir(mock_dir)
|
files = os.listdir(mock_dir)
|
||||||
for content_type, extensions in file_extensions.items():
|
for content_type, extensions in file_extensions.items():
|
||||||
filtered_files = filter_files_content_types(files, [content_type])
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
@ -30,7 +44,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions):
|
|||||||
assert f"sample_{content_type}.{extension}" in filtered_files
|
assert f"sample_{content_type}.{extension}" in filtered_files
|
||||||
|
|
||||||
|
|
||||||
def test_categorizes_all_uniquely(mock_dir, file_extensions):
|
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
|
||||||
files = os.listdir(mock_dir)
|
files = os.listdir(mock_dir)
|
||||||
for content_type, extensions in file_extensions.items():
|
for content_type, extensions in file_extensions.items():
|
||||||
filtered_files = filter_files_content_types(files, [content_type])
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user